From 1915755c7bc5a2232513091b665dccc1ba6cbe6c Mon Sep 17 00:00:00 2001
From: inisis
Date: Thu, 16 May 2024 11:42:56 +0000
Subject: [PATCH] fix inference failure when protobuf size is larger than 2GB

Signed-off-by: inisis
---
 .../onnx_graphsurgeon/ir/graph.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
index b40968f4..896b1f82 100644
--- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
+++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
@@ -1245,12 +1245,31 @@ def should_eval_foldable(tensor):
         else:
             names = [t.name for t in graph_clone.outputs]
             try:
+                import os
+                import tempfile
+                import onnx
                 import onnxruntime as onnxrt
 
+                onnx_model = export_onnx(graph_clone, do_type_check=False)
+                if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
+                    tmp_dir = tempfile.TemporaryDirectory()
+                    tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
+                    location = os.path.basename(tmp_path) + ".data"
+                    if os.path.exists(location):
+                        os.remove(location)
+                    onnx.save(
+                        onnx_model,
+                        tmp_path,
+                        save_as_external_data=True,
+                        all_tensors_to_one_file=True,
+                        location=location,
+                    )
+                    onnx_model = tmp_path
+                else:
+                    onnx_model = onnx_model.SerializeToString()
+
                 sess = onnxrt.InferenceSession(
-                    export_onnx(
-                        graph_clone, do_type_check=False
-                    ).SerializeToString(),
+                    onnx_model,
                     providers=ORT_PROVIDERS,
                 )
                 values = sess.run(names, {})
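
Note for reviewers: below is a minimal standalone sketch of the same workaround, outside of graphsurgeon, showing why the path-based route is needed. Protobuf cannot serialize a message of 2GB or more, so SerializeToString() fails for such models, while onnxruntime can load a model from a file path whose tensor data is stored externally. The function name create_session, the variable model, and the temporary-file names are illustrative only and not part of the patch.

    import os
    import tempfile

    import onnx
    import onnxruntime as onnxrt

    def create_session(model: onnx.ModelProto) -> onnxrt.InferenceSession:
        # Models at or above onnx.checker.MAXIMUM_PROTOBUF (2GB) cannot be
        # serialized to bytes, so write them to disk with tensor data stored
        # in an external file and hand ORT the path instead.
        if model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
            tmp_dir = tempfile.TemporaryDirectory()
            tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
            onnx.save(
                model,
                tmp_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=os.path.basename(tmp_path) + ".data",
            )
            sess = onnxrt.InferenceSession(
                tmp_path, providers=["CPUExecutionProvider"]
            )
            # Keep the TemporaryDirectory object alive alongside the session;
            # attaching it is a cheap guard against the directory being
            # cleaned up while the session may still reference the files.
            sess._tmp_dir = tmp_dir
            return sess
        # Models under the limit can still be passed to ORT as serialized
        # bytes, avoiding the extra disk round trip.
        return onnxrt.InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )

The bytes path is kept for small models because it avoids writing anything to disk; only oversized models pay the cost of the temporary external-data file.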