@@ -584,6 +584,7 @@ def save(
584
584
arg_inputs : Optional [Sequence [torch .Tensor ]] = None ,
585
585
kwarg_inputs : Optional [dict [str , Any ]] = None ,
586
586
retrace : bool = False ,
587
+ pickle_protocol : int = 2 ,
587
588
) -> None :
588
589
"""
589
590
Save the model to disk in the specified output format.
@@ -596,6 +597,7 @@ def save(
596
597
output_format (str): Format to save the model. Options include exported_program | torchscript.
597
598
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
598
599
This flag is experimental for now.
600
+ pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
599
601
"""
600
602
if isinstance (module , CudaGraphsTorchTensorRTModule ):
601
603
module = module .compiled_module
@@ -668,7 +670,9 @@ def save(
668
670
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
669
671
)
670
672
exp_program = export (module )
671
- torch .export .save (exp_program , file_path )
673
+ torch .export .save (
674
+ exp_program , file_path , pickle_protocol = pickle_protocol
675
+ )
672
676
else :
673
677
if arg_inputs is None :
674
678
raise ValueError (
@@ -680,4 +684,6 @@ def save(
680
684
kwargs = kwarg_inputs ,
681
685
strict = False ,
682
686
)
683
- torch .export .save (exp_program , file_path )
687
+ torch .export .save (
688
+ exp_program , file_path , pickle_protocol = pickle_protocol
689
+ )
0 commit comments