Skip to content

Commit ff143f1

Browse files
committed
Optimize LazyValues and SparseValues with Caching Mechanism
Signed-off-by: Phoenix <[email protected]>
1 parent c56f4c7 commit ff143f1

File tree

1 file changed

+17
-4
lines changed
  • tools/onnx-graphsurgeon/onnx_graphsurgeon/ir

1 file changed

+17
-4
lines changed

tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,18 @@ def __init__(self, tensor):
231231
self.shape = get_onnx_tensor_shape(self.tensor)
232232
self.dtype = get_onnx_tensor_dtype(self.tensor)
233233
self.nbytes = misc.volume(self.shape) * get_itemsize(self.dtype)
234+
self._cached_values = None # Initialize the cache
234235

235236
def load(self):
236237
"""
237-
Load a numpy array from the underlying tensor values.
238+
Load a numpy array from the underlying tensor values, using cache.
238239
239240
Returns:
240241
np.array: A numpy array containing the values of the tensor.
241242
"""
243+
if self._cached_values is not None:
244+
return self._cached_values # Return cached data if available
245+
242246
import onnx
243247
import onnx.numpy_helper
244248
from onnx_graphsurgeon.importers.onnx_importer import (
@@ -254,7 +258,8 @@ def load(self):
254258
f"If this is not what you intended, please avoid accessing the values of this constant tensor."
255259
)
256260

257-
return np.array(onnx.numpy_helper.to_array(self.tensor))
261+
self._cached_values = np.array(onnx.numpy_helper.to_array(self.tensor))
262+
return self._cached_values
258263

259264
def __str__(self):
260265
return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype)
@@ -268,13 +273,20 @@ class SparseValues(LazyValues):
268273
A special object that represents constant tensor values that is sparse
269274
"""
270275

276+
def __init__(self, tensor):
277+
super().__init__(tensor)
278+
self._cached_values = None # Initialize the cache
279+
271280
def load(self):
272281
"""
273-
Load a numpy array from the sparse structure.
282+
Load a numpy array from the sparse structure, using cache.
274283
275284
Returns:
276285
np.array: A numpy array containing the values of the tensor.
277286
"""
287+
if self._cached_values is not None:
288+
return self._cached_values # Return cached data if available
289+
278290
import onnx
279291
import onnx.numpy_helper
280292
from onnx_graphsurgeon.importers.onnx_importer import (
@@ -316,7 +328,8 @@ def load(self):
316328
f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}"
317329
)
318330

319-
return values
331+
self._cached_values = values
332+
return self._cached_values
320333

321334
def __str__(self):
322335
return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype)

0 commit comments

Comments
 (0)