@@ -19,12 +19,14 @@ def argwhere(input):
19
19
return ops .argwhere (input )
20
20
21
21
# cat
# Probe once at import time: mindspore.mint.cat only exists in newer
# MindSpore releases, so guard every pyboost fast path with hasattr.
has_cat = hasattr(mindspore.mint, 'cat')
def cat(tensors, dim=0):
    """Concatenate a sequence of tensors along ``dim`` (torch.cat-style).

    Args:
        tensors: sequence of tensors with matching shapes outside ``dim``.
        dim: dimension along which to concatenate (default 0).

    Returns:
        The concatenated tensor.
    """
    if use_pyboost() and has_cat:
        return mindspore.mint.cat(tensors, dim)
    # Fallback for older MindSpore / non-pyboost execution.
    return ops.cat(tensors, dim)
26
27
27
28
# concat
# NOTE(review): probed for consistency with the other wrappers, but unused —
# concat delegates to cat(), which performs its own mint/ops dispatch.
has_concat = hasattr(mindspore.mint, 'concat')
def concat(tensors, dim=0):
    """Alias of :func:`cat` (torch.concat-style)."""
    return cat(tensors, dim)
30
32
@@ -37,7 +39,10 @@ def conj(input):
37
39
return ops .conj (input )
38
40
39
41
# chunk
# One-time capability probe for older MindSpore without mint.chunk.
has_chunk = hasattr(mindspore.mint, 'chunk')
def chunk(input, chunks, dim=0):
    """Split ``input`` into ``chunks`` pieces along ``dim`` (torch.chunk-style).

    Args:
        input: tensor to split.
        chunks: number of chunks to return.
        dim: dimension along which to split (default 0).

    Returns:
        A tuple of sub-tensors.
    """
    if use_pyboost() and has_chunk:
        return mindspore.mint.chunk(input, chunks, dim)
    return ops.chunk(input, chunks, dim)
42
47
43
48
# dsplit
@@ -50,8 +55,9 @@ def chunk(input, chunks, dim=0):
50
55
51
56
52
57
# gather
# One-time capability probe for older MindSpore without mint.gather.
has_gather = hasattr(mindspore.mint, 'gather')
def gather(input, dim, index):
    """Gather values of ``input`` along ``dim`` at positions ``index``
    (torch.gather-style).

    Args:
        input: source tensor.
        dim: dimension along which to index.
        index: tensor of indices into ``dim``.

    Returns:
        Tensor of gathered values with the shape of ``index``.
    """
    if use_pyboost() and has_gather:
        return mindspore.mint.gather(input, dim, index)
    # Wrap out-of-range (negative-style) indices back into [0, shape[dim]) so
    # ops.gather_elements accepts them.
    index = ops.where(index < input.shape[dim], index, index - input.shape[dim])
    return ops.gather_elements(input, dim, index)
@@ -91,13 +97,17 @@ def inplace_index_add(input, dim, index, source):
91
97
92
98
93
99
# index_select
# One-time capability probe for older MindSpore without mint.index_select.
has_index_select = hasattr(mindspore.mint, 'index_select')
def index_select(input, dim, index):
    """Select slices of ``input`` along ``dim`` at ``index``
    (torch.index_select-style).

    Args:
        input: source tensor.
        dim: dimension to index into.
        index: 1-D tensor of indices.

    Returns:
        Tensor of the selected slices.
    """
    if use_pyboost() and has_index_select:
        return mindspore.mint.index_select(input, dim, index)
    return ops.index_select(input, dim, index)
98
105
99
106
# masked_select
# One-time capability probe for older MindSpore without mint.masked_select.
has_masked_select = hasattr(mindspore.mint, 'masked_select')
def masked_select(input, mask):
    """Return a 1-D tensor of the elements of ``input`` where ``mask`` is True
    (torch.masked_select-style).

    Args:
        input: source tensor.
        mask: boolean tensor broadcastable to ``input``.

    Returns:
        1-D tensor of the selected elements.
    """
    if use_pyboost() and has_masked_select:
        return mindspore.mint.masked_select(input, mask)
    return ops.masked_select(input, mask)
102
112
103
113
# movedim
@@ -107,17 +117,19 @@ def masked_select(input, mask):
107
117
108
118
109
119
# narrow
# One-time capability probe for older MindSpore without mint.narrow.
has_narrow = hasattr(mindspore.mint, 'narrow')
def narrow(input, dim, start, length):
    """Return a narrowed view of ``input``: ``length`` elements along ``dim``
    starting at ``start`` (torch.narrow-style).

    Args:
        input: source tensor.
        dim: dimension to narrow.
        start: starting index along ``dim``.
        length: number of elements to keep.

    Returns:
        The narrowed tensor.
    """
    if use_pyboost() and has_narrow:
        return mindspore.mint.narrow(input, dim, start, length)
    return ops.narrow(input, dim, start, length)
114
125
115
126
# narrow_copy
116
127
117
128
118
129
# nonzero
130
+ has_nonzero = hasattr (mindspore .mint , 'nonzero' )
119
131
def nonzero (input , * , as_tuple = False ):
120
- if use_pyboost ():
132
+ if use_pyboost () and has_nonzero :
121
133
return mindspore .mint .nonzero (input , as_tuple )
122
134
_nonzero = _get_cache_prim (ops .NonZero )()
123
135
out = _nonzero (input )
@@ -128,37 +140,40 @@ def nonzero(input, *, as_tuple=False):
128
140
return out
129
141
130
142
# permute
# One-time capability probe for older MindSpore without mint.permute.
has_permute = hasattr(mindspore.mint, 'permute')
def permute(input, dims):
    """Reorder the dimensions of ``input`` to ``dims`` (torch.permute-style).

    Args:
        input: source tensor.
        dims: tuple giving the new order of every dimension.

    Returns:
        Tensor with dimensions permuted.
    """
    if use_pyboost() and has_permute:
        return mindspore.mint.permute(input, dims)
    return ops.permute(input, dims)
135
148
136
149
# reshape
# One-time capability probe for older MindSpore without mint.reshape.
has_reshape = hasattr(mindspore.mint, 'reshape')
def reshape(input, shape):
    """Return ``input`` viewed with the given ``shape`` (torch.reshape-style).

    Args:
        input: source tensor.
        shape: target shape (tuple of ints).

    Returns:
        The reshaped tensor.
    """
    if use_pyboost() and has_reshape:
        return mindspore.mint.reshape(input, shape)
    return ops.reshape(input, shape)
139
155
140
156
def view(input, *shape):
    """torch.Tensor.view-style: reshape ``input`` to ``shape`` given as
    varargs. Delegates to :func:`reshape`, which handles mint/ops dispatch.
    """
    return reshape(input, shape)
144
158
145
159
# row_stack
146
160
147
161
# select
# One-time capability probe for older MindSpore without mint.select.
has_select = hasattr(mindspore.mint, 'select')
def select(input, dim, index):
    """Slice ``input`` at position ``index`` along ``dim``, removing that
    dimension (torch.select-style).

    Args:
        input: source tensor.
        dim: dimension to slice.
        index: integer index along ``dim``.

    Returns:
        Tensor with ``dim`` removed.
    """
    if use_pyboost() and has_select:
        return mindspore.mint.select(input, dim, index)
    # Fallback: build an indexing tuple of `dim` full slices followed by the
    # integer index, e.g. input[:, :, index] for dim == 2.
    slices = ()
    for _ in range(dim):
        slices += (slice(None, None, None),)
    slices += (index,)
    return input[slices]
154
171
155
172
# scatter
# One-time capability probe for older MindSpore without mint.scatter.
has_scatter = hasattr(mindspore.mint, 'scatter')
def scatter(input, dim, index, src):
    """Write ``src`` values into ``input`` at ``index`` along ``dim``
    (torch.scatter-style).

    Args:
        input: destination tensor.
        dim: dimension along which to scatter.
        index: tensor of target indices.
        src: tensor of values, or a scalar broadcast to ``index``'s shape.

    Returns:
        Tensor with the scattered values.
    """
    if use_pyboost() and has_scatter:
        return mindspore.mint.scatter(input, dim, index, src)
    # ops.tensor_scatter_elements needs a tensor src; expand a scalar src to
    # match index, using input's dtype.
    if not isinstance(src, mindspore.Tensor):
        src = ops.full(index.shape, src, dtype=input.dtype)
    return ops.tensor_scatter_elements(input, index, src, dim)
@@ -179,8 +194,9 @@ def tf_scatter_nd(indices, updates, shape):
179
194
180
195
181
196
# scatter_add
# One-time capability probe for older MindSpore without mint.scatter_add.
has_scatter_add = hasattr(mindspore.mint, 'scatter_add')
def scatter_add(input, dim, index, src):
    """Accumulate ``src`` into ``input`` at ``index`` along ``dim``
    (torch.scatter_add-style).

    Args:
        input: destination tensor.
        dim: dimension along which to scatter.
        index: tensor of target indices.
        src: tensor of values to add.

    Returns:
        Tensor with the accumulated values.
    """
    if use_pyboost() and has_scatter_add:
        return mindspore.mint.scatter_add(input, dim, index, src)
    # 'add' reduction makes tensor_scatter_elements accumulate, not overwrite.
    return ops.tensor_scatter_elements(input, index, src, dim, 'add')
186
202
@@ -196,8 +212,9 @@ def scatter_update(input, indices, updates):
196
212
return ops .scatter_update (input , indices , updates )
197
213
198
214
# split
# One-time capability probe for older MindSpore without mint.split.
has_split = hasattr(mindspore.mint, 'split')
def split(tensor, split_size_or_sections, dim=0):
    """Split ``tensor`` along ``dim`` (torch.split-style).

    Args:
        tensor: tensor to split.
        split_size_or_sections: chunk size (int) or explicit section sizes
            (sequence of ints).
        dim: dimension along which to split (default 0).

    Returns:
        A tuple of sub-tensors.
    """
    if use_pyboost() and has_split:
        return mindspore.mint.split(tensor, split_size_or_sections, dim)
    return ops.split(tensor, split_size_or_sections, dim)
203
220
@@ -206,8 +223,9 @@ def squeeze(input, dim=None):
206
223
return ops .squeeze (input , dim )
207
224
208
225
# stack
# One-time capability probe for older MindSpore without mint.stack.
has_stack = hasattr(mindspore.mint, 'stack')
def stack(tensors, dim=0):
    """Stack a sequence of same-shape tensors along a new dimension ``dim``
    (torch.stack-style).

    Args:
        tensors: sequence of tensors with identical shapes.
        dim: position of the new dimension (default 0).

    Returns:
        The stacked tensor.
    """
    if use_pyboost() and has_stack:
        return mindspore.mint.stack(tensors, dim)
    return ops.stack(tensors, dim)
213
231
@@ -235,20 +253,22 @@ def take(input, index):
235
253
236
254
237
255
# tile
# One-time capability probe for older MindSpore without mint.tile.
has_tile = hasattr(mindspore.mint, 'tile')
def tile(input, dims):
    """Repeat ``input`` along each dimension per ``dims`` (torch.tile-style).

    Args:
        input: source tensor.
        dims: per-dimension repeat counts.

    Returns:
        The tiled tensor.
    """
    if use_pyboost() and has_tile:
        return mindspore.mint.tile(input, dims)
    return ops.tile(input, dims)
242
261
243
262
# transpose
# One-time capability probe for older MindSpore without mint.transpose.
has_transpose = hasattr(mindspore.mint, 'transpose')
def transpose(input, dim0, dim1):
    """Swap dimensions ``dim0`` and ``dim1`` of ``input``
    (torch.transpose-style).

    Args:
        input: source tensor.
        dim0: first dimension to swap.
        dim1: second dimension to swap.

    Returns:
        Tensor with the two dimensions exchanged.
    """
    if use_pyboost() and has_transpose:
        return mindspore.mint.transpose(input, dim0, dim1)
    # Fallback: build the full permutation with dim0/dim1 exchanged and
    # delegate to permute(). Indexing ranks[dimX] also normalizes negative
    # dims via Python list indexing.
    ranks = list(range(input.ndim))
    rank0 = ranks[dim0]
    rank1 = ranks[dim1]
    ranks[dim0] = rank1
    ranks[dim1] = rank0
    return permute(input, tuple(ranks))
253
273
254
274
# unbind
@@ -258,7 +278,10 @@ def unbind(input, dim=0):
258
278
# unravel_index
259
279
260
280
# unsqueeze
# One-time capability probe for older MindSpore without mint.unsqueeze.
has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze')
def unsqueeze(input, dim):
    """Insert a size-1 dimension at position ``dim`` (torch.unsqueeze-style).

    Args:
        input: source tensor.
        dim: position for the new dimension.

    Returns:
        Tensor with the extra dimension.
    """
    if use_pyboost() and has_unsqueeze:
        return mindspore.mint.unsqueeze(input, dim)
    return ops.expand_dims(input, dim)
263
286
264
287
# vsplit
@@ -273,10 +296,11 @@ def vstack(input):
273
296
)
274
297
275
298
# where
# One-time capability probe for older MindSpore without mint.where.
has_where = hasattr(mindspore.mint, 'where')
def where(condition, input, other):
    """Element-wise select: ``input`` where ``condition`` is True, else
    ``other`` (torch.where-style).

    Args:
        condition: boolean tensor.
        input: values taken where ``condition`` holds.
        other: values taken elsewhere.

    Returns:
        The element-wise selected tensor.
    """
    if ON_ORANGE_PI:
        # Arithmetic select avoids the where kernel on Orange Pi devices.
        # NOTE(review): assumes input/other have a dtype where this masked
        # arithmetic is exact — confirm for float NaN propagation.
        return condition * input + (~condition) * other
    if use_pyboost() and has_where:
        return mindspore.mint.where(condition, input, other)
    return ops.where(condition, input, other)
282
306
0 commit comments