Skip to content

Commit 46eaebe

Browse files
authored
update core.ops with pyboost (#1884)
1 parent 68a6247 commit 46eaebe

File tree

10 files changed

+374
-93
lines changed

10 files changed

+374
-93
lines changed

.github/workflows/ci_pipeline.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ jobs:
8080
pip install -r download.txt
8181
- name: Test with pytest
8282
run: |
83-
pytest -c pytest.ini -m 'not download and not gpu_only' --ignore=tests/transformers tests/ut
83+
pytest -c pytest.ini -m 'not download and not gpu_only' --ignore=tests/transformers tests
8484
8585
release-test:
8686
needs: pylint-check
@@ -104,8 +104,8 @@ jobs:
104104
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${{matrix.ms_version}}/MindSpore/unified/x86_64/mindspore-${{matrix.ms_version}}-cp39-cp39-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com
105105
- name: Test with pytest
106106
run: |
107-
pytest -c pytest.ini -m 'not download and not gpu_only' --ignore=tests/transformers tests/ut
108-
# pytest -c pytest.ini -m 'not download and not gpu_only' tests/ut
107+
pytest -c pytest.ini -m 'not download and not gpu_only' --ignore=tests/transformers tests
108+
# pytest -c pytest.ini -m 'not download and not gpu_only' tests
109109
110110
transformers-model-test:
111111
needs: pylint-check

mindnlp/core/ops/array.py

+45-21
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ def argwhere(input):
1919
return ops.argwhere(input)
2020

2121
# cat
22+
has_cat = hasattr(mindspore.mint, 'cat')
2223
def cat(tensors, dim=0):
23-
if use_pyboost():
24+
if use_pyboost() and has_cat:
2425
return mindspore.mint.cat(tensors, dim)
2526
return ops.cat(tensors, dim)
2627

2728
# concat
29+
has_concat = hasattr(mindspore.mint, 'concat')
2830
def concat(tensors, dim=0):
2931
return cat(tensors, dim)
3032

@@ -37,7 +39,10 @@ def conj(input):
3739
return ops.conj(input)
3840

3941
# chunk
42+
has_chunk = hasattr(mindspore.mint, 'chunk')
4043
def chunk(input, chunks, dim=0):
44+
if use_pyboost() and has_chunk:
45+
return mindspore.mint.chunk(input, chunks, dim)
4146
return ops.chunk(input, chunks, dim)
4247

4348
# dsplit
@@ -50,8 +55,9 @@ def chunk(input, chunks, dim=0):
5055

5156

5257
# gather
58+
has_gather = hasattr(mindspore.mint, 'gather')
5359
def gather(input, dim, index):
54-
if use_pyboost():
60+
if use_pyboost() and has_gather:
5561
return mindspore.mint.gather(input, dim, index)
5662
index = ops.where(index < input.shape[dim], index, index - input.shape[dim])
5763
return ops.gather_elements(input, dim, index)
@@ -91,13 +97,17 @@ def inplace_index_add(input, dim, index, source):
9197

9298

9399
# index_select
100+
has_index_select = hasattr(mindspore.mint, 'index_select')
94101
def index_select(input, dim, index):
95-
if use_pyboost():
102+
if use_pyboost() and has_index_select:
96103
return mindspore.mint.index_select(input, dim, index)
97104
return ops.index_select(input, dim, index)
98105

99106
# masked_select
107+
has_masked_select = hasattr(mindspore.mint, 'masked_select')
100108
def masked_select(input, mask):
109+
if use_pyboost() and has_masked_select:
110+
return mindspore.mint.masked_select(input, mask)
101111
return ops.masked_select(input, mask)
102112

103113
# movedim
@@ -107,17 +117,19 @@ def masked_select(input, mask):
107117

108118

109119
# narrow
120+
has_narrow = hasattr(mindspore.mint, 'narrow')
110121
def narrow(input, dim, start, length):
111-
if use_pyboost():
122+
if use_pyboost() and has_narrow:
112123
return mindspore.mint.narrow(input, dim, start, length)
113124
return ops.narrow(input, dim, start, length)
114125

115126
# narrow_copy
116127

117128

118129
# nonzero
130+
has_nonzero = hasattr(mindspore.mint, 'nonzero')
119131
def nonzero(input, *, as_tuple=False):
120-
if use_pyboost():
132+
if use_pyboost() and has_nonzero:
121133
return mindspore.mint.nonzero(input, as_tuple)
122134
_nonzero = _get_cache_prim(ops.NonZero)()
123135
out = _nonzero(input)
@@ -128,37 +140,40 @@ def nonzero(input, *, as_tuple=False):
128140
return out
129141

130142
# permute
143+
has_permute = hasattr(mindspore.mint, 'permute')
131144
def permute(input, dims):
132-
if use_pyboost():
145+
if use_pyboost() and has_permute:
133146
return mindspore.mint.permute(input, dims)
134147
return ops.permute(input, dims)
135148

136149
# reshape
150+
has_reshape = hasattr(mindspore.mint, 'reshape')
137151
def reshape(input, shape):
152+
if use_pyboost() and has_reshape:
153+
return mindspore.mint.reshape(input, shape)
138154
return ops.reshape(input, shape)
139155

140156
def view(input, *shape):
141-
# if use_pyboost():
142-
# return mindspore.ops.auto_generate.gen_ops_prim.view_op(input, shape)
143-
return ops.reshape(input, shape)
157+
return reshape(input, shape)
144158

145159
# row_stack
146160

147161
# select
162+
has_select = hasattr(mindspore.mint, 'select')
148163
def select(input, dim, index):
164+
if use_pyboost() and has_select:
165+
return mindspore.mint.select(input, dim, index)
149166
slices = ()
150167
for _ in range(dim):
151168
slices += (slice(None, None, None),)
152169
slices += (index,)
153170
return input[slices]
154171

155172
# scatter
173+
has_scatter = hasattr(mindspore.mint, 'scatter')
156174
def scatter(input, dim, index, src):
157-
if use_pyboost():
158-
try:
159-
return mindspore.mint.scatter(input, dim, index, src)
160-
except:
161-
return mindspore.ops.auto_generate.gen_ops_prim.scatter_op(input, dim, index, src, 0)
175+
if use_pyboost() and has_scatter:
176+
return mindspore.mint.scatter(input, dim, index, src)
162177
if not isinstance(src, mindspore.Tensor):
163178
src = ops.full(index.shape, src, dtype=input.dtype)
164179
return ops.tensor_scatter_elements(input, index, src, dim)
@@ -179,8 +194,9 @@ def tf_scatter_nd(indices, updates, shape):
179194

180195

181196
# scatter_add
197+
has_scatter_add = hasattr(mindspore.mint, 'scatter_add')
182198
def scatter_add(input, dim, index, src):
183-
if use_pyboost():
199+
if use_pyboost() and has_scatter_add:
184200
return mindspore.mint.scatter_add(input, dim, index, src)
185201
return ops.tensor_scatter_elements(input, index, src, dim, 'add')
186202

@@ -196,8 +212,9 @@ def scatter_update(input, indices, updates):
196212
return ops.scatter_update(input, indices, updates)
197213

198214
# split
215+
has_split = hasattr(mindspore.mint, 'split')
199216
def split(tensor, split_size_or_sections, dim=0):
200-
if use_pyboost():
217+
if use_pyboost() and has_split:
201218
return mindspore.mint.split(tensor, split_size_or_sections, dim)
202219
return ops.split(tensor, split_size_or_sections, dim)
203220

@@ -206,8 +223,9 @@ def squeeze(input, dim=None):
206223
return ops.squeeze(input, dim)
207224

208225
# stack
226+
has_stack = hasattr(mindspore.mint, 'stack')
209227
def stack(tensors, dim=0):
210-
if use_pyboost():
228+
if use_pyboost() and has_stack:
211229
return mindspore.mint.stack(tensors, dim)
212230
return ops.stack(tensors, dim)
213231

@@ -235,20 +253,22 @@ def take(input, index):
235253

236254

237255
# tile
256+
has_tile = hasattr(mindspore.mint, 'tile')
238257
def tile(input, dims):
239-
if use_pyboost():
258+
if use_pyboost() and has_tile:
240259
return mindspore.mint.tile(input, dims)
241260
return ops.tile(input, dims)
242261

243262
# transpose
263+
has_transpose = hasattr(mindspore.mint, 'transpose')
244264
def transpose(input, dim0, dim1):
265+
if use_pyboost() and has_transpose:
266+
return mindspore.mint.transpose(input, dim0, dim1)
245267
ranks = list(range(input.ndim))
246268
rank0 = ranks[dim0]
247269
rank1 = ranks[dim1]
248270
ranks[dim0] = rank1
249271
ranks[dim1] = rank0
250-
if use_pyboost():
251-
return mindspore.ops.auto_generate.gen_ops_prim.transpose_op(input, tuple(ranks))
252272
return permute(input, tuple(ranks))
253273

254274
# unbind
@@ -258,7 +278,10 @@ def unbind(input, dim=0):
258278
# unravel_index
259279

260280
# unsqueeze
281+
has_unsqueeze = hasattr(mindspore.mint, 'unsqueeze')
261282
def unsqueeze(input, dim):
283+
if use_pyboost() and has_unsqueeze:
284+
return mindspore.mint.unsqueeze(input, dim)
262285
return ops.expand_dims(input, dim)
263286

264287
# vsplit
@@ -273,10 +296,11 @@ def vstack(input):
273296
)
274297

275298
# where
299+
has_where = hasattr(mindspore.mint, 'where')
276300
def where(condition, input, other):
277301
if ON_ORANGE_PI:
278302
return condition * input + (~condition) * other
279-
if use_pyboost():
303+
if use_pyboost() and has_where:
280304
return mindspore.mint.where(condition, input, other)
281305
return ops.where(condition, input, other)
282306

mindnlp/core/ops/blas.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1):
1919

2020

2121
# baddbmm
22+
has_baddbmm = hasattr(mindspore.mint, 'baddbmm')
2223
def baddbmm(input, batch1, batch2, *, beta=1, alpha=1):
24+
if use_pyboost() and has_baddbmm:
25+
return mindspore.mint.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)
2326
return ops.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)
2427

2528
# bmm
29+
has_bmm = hasattr(mindspore.mint, 'bmm')
2630
def bmm(input, other):
2731
if ON_ORANGE_PI:
2832
input = input.to(mindspore.float16)
2933
other = input.to(mindspore.float16)
30-
if use_pyboost():
34+
if use_pyboost() and has_bmm:
3135
return mindspore.mint.bmm(input, other)
3236
return ops.bmm(input, other)
3337

@@ -66,11 +70,12 @@ def dot(input, other):
6670
# lu_unpack
6771

6872
# matmul
73+
has_matmul = hasattr(mindspore.mint, 'matmul')
6974
def matmul(input, other):
7075
if ON_ORANGE_PI:
7176
input = input.to(mindspore.float16)
7277
other = other.to(mindspore.float16)
73-
if use_pyboost():
78+
if use_pyboost() and has_matmul:
7479
return mindspore.mint.matmul(input, other)
7580
return ops.matmul(input, other)
7681

@@ -90,7 +95,10 @@ def mm(input, other):
9095
# ormqr
9196

9297
# outer
98+
has_outer = hasattr(mindspore.mint, 'outer')
9399
def outer(input, vec2):
100+
if use_pyboost() and has_outer:
101+
return mindspore.mint.outer(input, vec2)
94102
return ops.outer(input, vec2)
95103

96104
# pinverse

mindnlp/core/ops/comparison.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ def argsort(input, dim=-1, descending=False, stable=False):
1313
return sort(input, dim=dim, descending=descending, stable=stable)[1]
1414

1515
# eq
16+
has_eq = hasattr(mindspore.mint, 'eq')
1617
def eq(input, other):
17-
if use_pyboost():
18+
if use_pyboost() and has_eq:
1819
return mindspore.mint.eq(input, other)
1920
return ops.eq(input, other)
2021

@@ -27,24 +28,28 @@ def ge(input, other):
2728
return ops.ge(input, other)
2829

2930
# gt
31+
has_gt = hasattr(mindspore.mint, 'gt')
3032
def gt(input, other):
31-
if use_pyboost():
33+
if use_pyboost() and has_gt:
3234
return mindspore.mint.gt(input, other)
3335
return ops.gt(input, other)
3436

3537
# greater
38+
has_greater = hasattr(mindspore.mint, 'greater')
3639
def greater(input, other):
3740
return gt(input, other)
3841

3942
# isclose
43+
has_isclose = hasattr(mindspore.mint, 'isclose')
4044
def isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
41-
if use_pyboost():
45+
if use_pyboost() and has_isclose:
4246
return mindspore.mint.isclose(input, other, rtol, atol, equal_nan)
4347
return mindspore.tensor(np.isclose(input.numpy(), other.numpy(), rtol, atol, equal_nan))
4448

4549
# isfinite
50+
has_isfinite = hasattr(mindspore.mint, 'isfinite')
4651
def isfinite(input):
47-
if use_pyboost():
52+
if use_pyboost() and has_isfinite:
4853
return mindspore.mint.isfinite(input)
4954
return ops.isfinite(input)
5055

@@ -79,8 +84,9 @@ def isnan(input):
7984
# kthvalue
8085

8186
# le
87+
has_le = hasattr(mindspore.mint, 'le')
8288
def le(input, other):
83-
if use_pyboost():
89+
if use_pyboost() and has_le:
8490
return mindspore.mint.le(input, other)
8591
return ops.le(input, other)
8692

@@ -89,8 +95,9 @@ def less_equal(input, other):
8995
return le(input, other)
9096

9197
# lt
98+
has_lt = hasattr(mindspore.mint, 'lt')
9299
def lt(input, other):
93-
if use_pyboost():
100+
if use_pyboost() and has_lt:
94101
return mindspore.mint.lt(input, other)
95102
return ops.lt(input, other)
96103

@@ -99,14 +106,16 @@ def less(input, other):
99106
return lt(input, other)
100107

101108
# maximum
109+
has_maximum = hasattr(mindspore.mint, 'maximum')
102110
def maximum(input, other):
103-
if use_pyboost():
111+
if use_pyboost() and has_maximum:
104112
return mindspore.mint.maximum(input, other)
105113
return ops.maximum(input, other)
106114

107115
# minimum
116+
has_minimum = hasattr(mindspore.mint, 'minimum')
108117
def minimum(input, other):
109-
if use_pyboost():
118+
if use_pyboost() and has_minimum:
110119
return mindspore.mint.minimum(input, other)
111120
return ops.minimum(input, other)
112121

@@ -120,8 +129,9 @@ def fmin(input, other):
120129
return ops.fmin(input, other)
121130

122131
# ne
132+
has_ne = hasattr(mindspore.mint, 'ne')
123133
def ne(input, other):
124-
if use_pyboost():
134+
if use_pyboost() and has_ne:
125135
return mindspore.mint.ne(input, other)
126136
return ops.ne(input, other)
127137

@@ -130,14 +140,16 @@ def not_equal(input, other):
130140
return ne(input, other)
131141

132142
# sort
143+
has_sort = hasattr(mindspore.mint, 'sort')
133144
def sort(input, *, dim=-1, descending=False, stable=False):
134-
if use_pyboost():
145+
if use_pyboost() and has_sort:
135146
return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
136147
return ops.sort(input, dim, descending)
137148

138149
# topk
150+
has_topk = hasattr(mindspore.mint, 'topk')
139151
def topk(input, k, dim=-1, largest=True, sorted=True):
140-
if use_pyboost():
152+
if use_pyboost() and has_topk:
141153
return mindspore.mint.topk(input, k, dim, largest, sorted)
142154
return ops.topk(input, k, dim, largest, sorted)
143155

0 commit comments

Comments
 (0)