@@ -231,14 +231,18 @@ def __init__(self, tensor):
231
231
self .shape = get_onnx_tensor_shape (self .tensor )
232
232
self .dtype = get_onnx_tensor_dtype (self .tensor )
233
233
self .nbytes = misc .volume (self .shape ) * get_itemsize (self .dtype )
234
+ self ._cached_values = None # Initialize the cache
234
235
235
236
def load (self ):
236
237
"""
237
- Load a numpy array from the underlying tensor values.
238
+ Load a numpy array from the underlying tensor values, using cache .
238
239
239
240
Returns:
240
241
np.array: A numpy array containing the values of the tensor.
241
242
"""
243
+ if self ._cached_values is not None :
244
+ return self ._cached_values # Return cached data if available
245
+
242
246
import onnx
243
247
import onnx .numpy_helper
244
248
from onnx_graphsurgeon .importers .onnx_importer import (
@@ -254,7 +258,8 @@ def load(self):
254
258
f"If this is not what you intended, please avoid accessing the values of this constant tensor."
255
259
)
256
260
257
- return np .array (onnx .numpy_helper .to_array (self .tensor ))
261
+ self ._cached_values = np .array (onnx .numpy_helper .to_array (self .tensor ))
262
+ return self ._cached_values
258
263
259
264
def __str__ (self ):
260
265
return "LazyValues (shape={:}, dtype={:})" .format (self .shape , self .dtype )
@@ -268,13 +273,20 @@ class SparseValues(LazyValues):
268
273
A special object that represents constant tensor values that is sparse
269
274
"""
270
275
276
+ def __init__ (self , tensor ):
277
+ super ().__init__ (tensor )
278
+ self ._cached_values = None # Initialize the cache
279
+
271
280
def load (self ):
272
281
"""
273
- Load a numpy array from the sparse structure.
282
+ Load a numpy array from the sparse structure, using cache .
274
283
275
284
Returns:
276
285
np.array: A numpy array containing the values of the tensor.
277
286
"""
287
+ if self ._cached_values is not None :
288
+ return self ._cached_values # Return cached data if available
289
+
278
290
import onnx
279
291
import onnx .numpy_helper
280
292
from onnx_graphsurgeon .importers .onnx_importer import (
@@ -316,7 +328,8 @@ def load(self):
316
328
f"Unsupported index data dims { self .tensor .indices .dims } in { self .tensor .values .name } "
317
329
)
318
330
319
- return values
331
+ self ._cached_values = values
332
+ return self ._cached_values
320
333
321
334
def __str__ (self ):
322
335
return "SparseValues (shape={:}, dtype={:})" .format (self .shape , self .dtype )
0 commit comments