Commit e104318 — chore: miscellaneous fixes for handling graph breaks (#3488)
1 parent: a8ecd79
File tree: 3 files changed (+14, −4 lines)

py/torch_tensorrt/_compile.py (+8, −2)

@@ -584,6 +584,7 @@ def save(
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
+    pickle_protocol: int = 2,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -596,6 +597,7 @@ def save(
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
+        pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
@@ -668,7 +670,9 @@ def save(
                 "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
             )
         exp_program = export(module)
-        torch.export.save(exp_program, file_path)
+        torch.export.save(
+            exp_program, file_path, pickle_protocol=pickle_protocol
+        )
     else:
         if arg_inputs is None:
             raise ValueError(
@@ -680,4 +684,6 @@ def save(
             kwargs=kwarg_inputs,
             strict=False,
         )
-        torch.export.save(exp_program, file_path)
+        torch.export.save(
+            exp_program, file_path, pickle_protocol=pickle_protocol
+        )

py/torch_tensorrt/dynamo/conversion/truncate_double.py (+1, −1)

@@ -195,7 +195,7 @@ def repair_double_inputs(

         # If the data type of the input is long/double, insert necessary
         # casts to replace the operation
-        if param.dtype == torch.float64:
+        if isinstance(param, torch.Tensor) and param.dtype == torch.float64:
             # Ensure outputs are only repaired once per submodule to avoid
             # unnecessary ops showing up in the graph
             if not repaired_outputs_once:

py/torch_tensorrt/dynamo/utils.py (+5, −1)

@@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+    if isinstance(tensor, (torch.Tensor, FakeTensor)):
+        return tensor.dtype
+    elif isinstance(tensor, (int, float, bool)):
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
@@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
                 output_dtypes.append(dtype.float32)
             else:
                 output_dtypes.append(dtype._from(output_meta.dtype))
+        elif isinstance(output_meta, torch.SymInt):
+            output_dtypes.append(dtype.int64)
         elif "tensor_meta" in output.meta:
             output_meta = output.meta["tensor_meta"]
             output_dtypes.append(dtype._from(output_meta.dtype))

Commit comments: 0