diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
index b40968f4..896b1f82 100644
--- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
+++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
@@ -1245,12 +1245,34 @@ def should_eval_foldable(tensor):
             else:
                 names = [t.name for t in graph_clone.outputs]
                 try:
+                    import os
+                    import tempfile
+                    import onnx
                     import onnxruntime as onnxrt
 
+                    onnx_model = export_onnx(graph_clone, do_type_check=False)
+                    if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
+                        # At or above the protobuf size limit: write the model
+                        # with external data and hand ORT the file path instead.
+                        # Keep tmp_dir referenced; the directory is deleted when
+                        # the TemporaryDirectory object is finalized.
+                        tmp_dir = tempfile.TemporaryDirectory()
+                        tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
+                        # `location` is relative to tmp_path's directory, which is
+                        # new and empty, so no stale data file needs removing.
+                        onnx.save(
+                            onnx_model,
+                            tmp_path,
+                            save_as_external_data=True,
+                            all_tensors_to_one_file=True,
+                            location=os.path.basename(tmp_path) + ".data",
+                        )
+                        onnx_model = tmp_path
+                    else:
+                        onnx_model = onnx_model.SerializeToString()
+
                     sess = onnxrt.InferenceSession(
-                        export_onnx(
-                            graph_clone, do_type_check=False
-                        ).SerializeToString(),
+                        onnx_model,
                         providers=ORT_PROVIDERS,
                     )
                     values = sess.run(names, {})