slight code reorg and bug correction for cross_compile #3472

Status: Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -1206,7 +1206,7 @@ def save_cross_compiled_exported_program(

from torch_tensorrt.dynamo._exporter import export

exp_program = export(gm, cross_compile_flag=True)
exp_program = export(gm, cross_compile_module=True)
torch.export.save(exp_program, file_path)
logger.debug(f"successfully saved the module for windows at {file_path}")

49 changes: 21 additions & 28 deletions py/torch_tensorrt/dynamo/_exporter.py
@@ -22,23 +22,23 @@

def export(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
cross_compile_module: Optional[bool] = False,
) -> ExportedProgram:
"""Export the result of TensorRT compilation into the desired output format.
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
"""
patched_module = transform(gm, cross_compile_flag)
patched_module = transform(gm, cross_compile_module)
exp_program = create_trt_exp_program(patched_module)
return exp_program


def transform(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
cross_compile_module: Optional[bool] = False,
) -> torch.fx.GraphModule:
"""
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -48,7 +48,7 @@ def transform(
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
cross_compile_module (bool): Flag to indicate whether cross-compilation is enabled or not
Returns an inlined torch.fx.GraphModule
"""
@@ -57,7 +57,7 @@ def transform(
gm = copy.deepcopy(gm)

# Inline TensorRT submodules
inline_trt_modules(gm, cross_compile_flag)
inline_trt_modules(gm, cross_compile_module)

# Inline pytorch submodules
inline_torch_modules(gm)
@@ -356,7 +356,7 @@ def create_trt_exp_program(


def inline_trt_modules(
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False
) -> torch.fx.GraphModule:
"""
Replace TRT submodules with trt engine nodes.
@@ -380,7 +380,16 @@ def inline_trt_modules(
num_outputs = len(trt_module_node.meta["val"])
# Insert a call_function node to perform inference on TRT engine
with gm.graph.inserting_before(trt_module_node):
if not cross_compile_flag:
if cross_compile_module:
engine_info = trt_module._pack_engine_info()
engine_bytes = engine_info[ENGINE_IDX]
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
# insert the no_op_placeholder node in the graph; it should be replaced with the actual execute_engine node when loaded on Windows
trt_node = gm.graph.call_function(
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
(trt_module_node.args, *engine_info),
)

Collaborator:

Do we still need to unpack this list?

Collaborator (Author):

We would still need to unpack the list; otherwise, loading on Windows fails with:

File "C:\Users\abose\Documents\work\TensorRT\torchTRT\Lib\site-packages\torch\_export\serde\serialize.py", line 2258, in deserialize_inputs
  args.append(actual_args[schema_arg.name])
        ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
KeyError: 'name'
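A minimal sketch of the schema binding behind this thread, in plain Python with hypothetical names and values (independent of FX, TensorRT, and the real op registration): the custom op declares each serialized field as its own positional parameter, so the packed engine_info list has to be splatted for every argument to bind to a schema name.

# Hypothetical stand-in for the op schema: each serialized field is a
# separate named positional parameter, loosely mirroring
# tensorrt::no_op_placeholder_for_execute_engine.
def placeholder_schema(inputs, abi_version, name, serialized_engine):
    return inputs

engine_info = ["abi_v1", "module_0", "b64-engine-bytes"]  # hypothetical values

placeholder_schema([1, 2], *engine_info)  # OK: each field binds to a parameter

try:
    placeholder_schema([1, 2], engine_info)  # whole list binds to abi_version
except TypeError as err:
    # 'name' and 'serialized_engine' are left unbound, analogous to the
    # KeyError: 'name' raised during deserialization on Windows.
    print(err)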
else:
# for the normal workflow: use the execute_engine node
engine_name = f"{name}_engine"
setattr(gm, engine_name, trt_module.engine)
@@ -396,16 +405,6 @@ def inline_trt_modules(
engine_node.meta["val"] = CustomObjArgument(
name=engine_node.name, class_fqn=""
)
else:
# for the cross compile for windows workflow: use the no_op_placeholder node
engine_info = trt_module._pack_engine_info()
engine_bytes = engine_info[ENGINE_IDX]
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
trt_node = gm.graph.call_function(
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
(trt_module_node.args, *engine_info),
)
# set trt_node.meta with trt_module_node.meta
assert num_outputs > 0
trt_node.meta["val"] = trt_module_node.meta["val"]
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node(
name=engine_node.name, class_fqn=""
)

if len(no_op_placeholder_node.meta["val"]) == 1:
with gm.graph.inserting_after(trt_node):
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
getitem_output.meta["val"] = trt_node.meta["val"]
no_op_placeholder_node.replace_all_uses_with(getitem_output)
else:
no_op_placeholder_node.replace_all_uses_with(trt_node)
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
getitem_node.meta["val"] = trt_node.meta["val"][idx]
no_op_placeholder_node.replace_all_uses_with(trt_node)
Collaborator:

Can you add a multi-output test case to the cross-compile tests?

getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
getitem_node.meta["val"] = trt_node.meta["val"][idx]

Collaborator (Author):

@narendasan this is the part that should address the bug.

gm.graph.erase_node(no_op_placeholder_node)

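A standalone torch.fx sketch of the uniform rewrite discussed in the thread above, using a stand-in multi-output graph rather than the PR's TRT node: after replace_all_uses_with, the existing getitem users already index the replacement node, so only their meta["val"] entries need refreshing per output slot, and the old single-output special case becomes unnecessary.

import operator

import torch
import torch.fx as fx

# Stand-in module whose traced graph has a multi-output node consumed
# through operator.getitem, loosely mirroring an inlined TRT engine node.
class MultiOut(torch.nn.Module):
    def forward(self, a, b):
        out = torch.stack((a + b, a - b))
        return out[0], out[1]

gm = fx.symbolic_trace(MultiOut())
multi_node = next(n for n in gm.graph.nodes if n.target is torch.stack)

# Pretend the node was swapped for a replacement carrying per-output
# metadata, then refresh each getitem user from its output slot; this is
# the same pattern the fixed replace_execute_engine_no_op_node relies on.
multi_node.meta["val"] = [torch.empty(2, 3), torch.empty(2, 3)]
for idx, user in enumerate(
    u for u in multi_node.users if u.target is operator.getitem
):
    user.meta["val"] = multi_node.meta["val"][idx]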
21 changes: 21 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py
@@ -150,3 +150,24 @@ def __setstate__(self, serialized_state: List[str]) -> Any:

def __getstate__(self) -> Any:
pass


@torch.library.custom_op(
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
)
def no_op_placeholder_for_execute_engine(
inputs: List[torch.Tensor],
abi_version: str,
name: str,
serialized_device_info: str,
serialized_engine: str,
serialized_in_binding_names: str,
serialized_out_binding_names: str,
serialized_hardware_compatible: str,
serialized_metadata: str,
serialized_target_platform: str,
serialized_require_output_allocator: str,
) -> List[torch.Tensor]:
raise RuntimeError(
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
)
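For context, a sketch of the end-to-end flow this placeholder guards (API names taken from this PR's tests and error message; the toy module, shapes, and file path are hypothetical): the program is cross-compiled and saved on Linux, and only the dedicated loader on Windows swaps the no-op placeholder back to a real execute_engine node.

import torch
import torch_tensorrt

class Add(torch.nn.Module):  # toy module, as in the tests below
    def forward(self, a, b):
        return torch.add(a, b)

inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())
exp_program = torch.export.export(Add().eval().cuda(), inputs)

# On Linux: cross-compile for Windows and serialize the exported program.
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
    exp_program, inputs=inputs, min_block_size=1
)
torch_tensorrt.dynamo.save_cross_compiled_exported_program(
    trt_gm, file_path="trt.ep"
)

# On Windows: this loader replaces the no_op placeholder with execute_engine;
# a plain torch.export.load would hit the RuntimeError above.
loaded = torch_tensorrt.load_cross_compiled_exported_program("trt.ep")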
20 changes: 0 additions & 20 deletions py/torch_tensorrt/runtime/_utils.py
@@ -128,23 +128,3 @@ def _get_most_compatible_device(
best_match = candidate

return best_match


@torch.library.custom_op(
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
)
def no_op_placeholder_for_execute_engine(
inputs: List[torch.Tensor],
abi_version: str,
name: str,
serialized_device_info: str,
serialized_engine: str,
serialized_in_binding_names: str,
serialized_out_binding_names: str,
serialized_hardware_compatible: str,
serialized_metadata: str,
serialized_target_platform: str,
) -> List[torch.Tensor]:
raise RuntimeError(
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
)
28 changes: 28 additions & 0 deletions tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py
@@ -63,3 +63,31 @@ def forward(self, a, b):
)
except Exception as e:
pytest.fail(f"unexpected exception raised: {e}")

@unittest.skipIf(
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@pytest.mark.unit
def test_dynamo_cross_compile_for_windows_multiple_output(self):
class Add(torch.nn.Module):
def forward(self, a, b):
return torch.add(a, b), torch.add(a, b)

model = Add().eval().cuda()
inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")
exp_program = torch.export.export(model, inputs)
compile_spec = {
"inputs": inputs,
"min_block_size": 1,
}
try:
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
exp_program, **compile_spec
)
torch_tensorrt.dynamo.save_cross_compiled_exported_program(
trt_gm, file_path=trt_ep_path
)
except Exception as e:
pytest.fail(f"unexpected exception raised: {e}")