-
Notifications
You must be signed in to change notification settings - Fork 362
slight code reorg and bug correction for cross_compile #3472
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
apbose
wants to merge
2
commits into
main
Choose a base branch
from
cross_compile_code_reorg_and_corr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+71
−49
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,23 +22,23 @@ | |
|
||
def export( | ||
gm: torch.fx.GraphModule, | ||
cross_compile_flag: Optional[bool] = False, | ||
cross_compile_module: Optional[bool] = False, | ||
) -> ExportedProgram: | ||
"""Export the result of TensorRT compilation into the desired output format. | ||
Arguments: | ||
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` | ||
inputs (torch.Tensor): Torch input tensors | ||
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not | ||
cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not | ||
""" | ||
patched_module = transform(gm, cross_compile_flag) | ||
patched_module = transform(gm, cross_compile_module) | ||
exp_program = create_trt_exp_program(patched_module) | ||
return exp_program | ||
|
||
|
||
def transform( | ||
gm: torch.fx.GraphModule, | ||
cross_compile_flag: Optional[bool] = False, | ||
cross_compile_module: Optional[bool] = False, | ||
) -> torch.fx.GraphModule: | ||
""" | ||
Transforms the graphmodule by inlining Pytorch and TensorRT submodules. | ||
|
@@ -48,7 +48,7 @@ def transform( | |
Arguments: | ||
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` | ||
inputs (torch.Tensor): Torch input tensors | ||
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not | ||
cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not | ||
Returns an inlined torch.fx.GraphModule | ||
""" | ||
|
@@ -57,7 +57,7 @@ def transform( | |
gm = copy.deepcopy(gm) | ||
|
||
# Inline TensorRT submodules | ||
inline_trt_modules(gm, cross_compile_flag) | ||
inline_trt_modules(gm, cross_compile_module) | ||
|
||
# Inline pytorch submodules | ||
inline_torch_modules(gm) | ||
|
@@ -356,7 +356,7 @@ def create_trt_exp_program( | |
|
||
|
||
def inline_trt_modules( | ||
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False | ||
gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False | ||
) -> torch.fx.GraphModule: | ||
""" | ||
Replace TRT submodules with trt engine nodes. | ||
|
@@ -380,7 +380,16 @@ def inline_trt_modules( | |
num_outputs = len(trt_module_node.meta["val"]) | ||
# Insert a call_function node to perform inference on TRT engine | ||
with gm.graph.inserting_before(trt_module_node): | ||
if not cross_compile_flag: | ||
if cross_compile_module: | ||
engine_info = trt_module._pack_engine_info() | ||
engine_bytes = engine_info[ENGINE_IDX] | ||
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") | ||
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows | ||
trt_node = gm.graph.call_function( | ||
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, | ||
(trt_module_node.args, *engine_info), | ||
) | ||
else: | ||
# for the normal workflow: use the execute_engine node | ||
engine_name = f"{name}_engine" | ||
setattr(gm, engine_name, trt_module.engine) | ||
|
@@ -396,16 +405,6 @@ def inline_trt_modules( | |
engine_node.meta["val"] = CustomObjArgument( | ||
name=engine_node.name, class_fqn="" | ||
) | ||
else: | ||
# for the cross compile for windows workflow: use the no_op_placeholder node | ||
engine_info = trt_module._pack_engine_info() | ||
engine_bytes = engine_info[ENGINE_IDX] | ||
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") | ||
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows | ||
trt_node = gm.graph.call_function( | ||
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, | ||
(trt_module_node.args, *engine_info), | ||
) | ||
# set trt_node.meta with trt_module_node.meta | ||
assert num_outputs > 0 | ||
trt_node.meta["val"] = trt_module_node.meta["val"] | ||
|
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node( | |
name=engine_node.name, class_fqn="" | ||
) | ||
|
||
if len(no_op_placeholder_node.meta["val"]) == 1: | ||
with gm.graph.inserting_after(trt_node): | ||
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) | ||
getitem_output.meta["val"] = trt_node.meta["val"] | ||
no_op_placeholder_node.replace_all_uses_with(getitem_output) | ||
else: | ||
no_op_placeholder_node.replace_all_uses_with(trt_node) | ||
getitem_nodes = trt_node.users | ||
for idx, getitem_node in enumerate(getitem_nodes): | ||
getitem_node.meta["val"] = trt_node.meta["val"][idx] | ||
no_op_placeholder_node.replace_all_uses_with(trt_node) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a multi output testcase to the cross compile tests? |
||
getitem_nodes = trt_node.users | ||
for idx, getitem_node in enumerate(getitem_nodes): | ||
getitem_node.meta["val"] = trt_node.meta["val"][idx] | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @narendasan this is the part which should address the bug |
||
gm.graph.erase_node(no_op_placeholder_node) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need to unpack this list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would still need to unpack the list. Else while loading in windows it shows