@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
133
133
ThrowHelper . ThrowArgument_ConcatenateTooFewTensors ( ) ;
134
134
135
135
if ( dimension < - 1 || dimension > tensors [ 0 ] . Rank )
136
- ThrowHelper . ThrowArgument_InvalidAxis ( ) ;
136
+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
137
137
138
- // Calculate total space needed.
139
- nint totalLength = 0 ;
140
- for ( int i = 0 ; i < tensors . Length ; i ++ )
141
- totalLength += tensors [ i ] . FlattenedLength ;
138
+ Tensor < T > tensor ;
142
139
143
- nint sumOfAxis = 0 ;
144
140
// If axis != -1, make sure all dimensions except the one to concatenate on match.
145
141
if ( dimension != - 1 )
146
142
{
147
- sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
143
+ nint sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
148
144
for ( int i = 1 ; i < tensors . Length ; i ++ )
149
145
{
150
146
if ( tensors [ 0 ] . Rank != tensors [ i ] . Rank )
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
157
153
ThrowHelper . ThrowArgument_InvalidConcatenateShape ( ) ;
158
154
}
159
155
}
160
- sumOfAxis += tensors [ i ] . Lengths [ dimension ] ;
156
+ checked
157
+ {
158
+ sumOfAxis += tensors [ i ] . Lengths [ dimension ] ;
159
+ }
161
160
}
162
- }
163
161
164
- Tensor < T > tensor ;
165
- if ( dimension == - 1 )
166
- {
167
- tensor = Tensor . Create < T > ( [ totalLength ] ) ;
168
- }
169
- else
170
- {
171
162
nint [ ] lengths = new nint [ tensors [ 0 ] . Rank ] ;
172
163
tensors [ 0 ] . Lengths . CopyTo ( lengths ) ;
173
164
lengths [ dimension ] = sumOfAxis ;
174
165
tensor = Tensor . Create < T > ( lengths ) ;
175
166
}
167
+ else
168
+ {
169
+ // Calculate total space needed.
170
+ nint totalLength = 0 ;
171
+ for ( int i = 0 ; i < tensors . Length ; i ++ )
172
+ {
173
+ checked
174
+ {
175
+ totalLength += tensors [ i ] . FlattenedLength ;
176
+ }
177
+ }
178
+
179
+ tensor = Tensor . Create < T > ( [ totalLength ] ) ;
180
+ }
176
181
177
182
ConcatenateOnDimension ( dimension , tensors , tensor ) ;
178
183
return tensor ;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
201
206
ThrowHelper . ThrowArgument_ConcatenateTooFewTensors ( ) ;
202
207
203
208
if ( dimension < - 1 || dimension > tensors [ 0 ] . Rank )
204
- ThrowHelper . ThrowArgument_InvalidAxis ( ) ;
209
+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
205
210
206
211
// Calculate total space needed.
207
212
nint totalLength = 0 ;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
212
217
if ( dimension != - 1 )
213
218
{
214
219
nint sumOfAxis = tensors [ 0 ] . Lengths [ dimension ] ;
220
+ int rank = tensors [ 0 ] . Rank ;
215
221
for ( int i = 1 ; i < tensors . Length ; i ++ )
216
222
{
217
- if ( tensors [ 0 ] . Rank != tensors [ i ] . Rank )
223
+ if ( rank != tensors [ i ] . Rank )
218
224
ThrowHelper . ThrowArgument_InvalidConcatenateShape ( ) ;
219
- for ( int j = 0 ; j < tensors [ 0 ] . Rank ; j ++ )
225
+ for ( int j = 0 ; j < rank ; j ++ )
220
226
{
221
227
if ( j != dimension )
222
228
{
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
228
234
}
229
235
230
236
// Make sure the destination tensor has the correct shape.
231
- nint [ ] lengths = new nint [ tensors [ 0 ] . Rank ] ;
237
+ nint [ ] lengths = new nint [ rank ] ;
232
238
tensors [ 0 ] . Lengths . CopyTo ( lengths ) ;
233
239
lengths [ dimension ] = sumOfAxis ;
234
240
@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint
339
345
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns>
340
346
public static Tensor < T > Create < T > ( IEnumerable < T > enumerable , bool pinned = false )
341
347
{
348
+ T [ ] array = enumerable . ToArray ( ) ;
349
+
342
350
if ( pinned )
343
351
{
344
- T [ ] array = enumerable . ToArray ( ) ;
345
-
346
352
Tensor < T > tensor = CreateUninitialized < T > ( [ array . Length ] , pinned ) ;
347
353
array . CopyTo ( tensor . _values ) ;
348
354
349
355
return tensor ;
350
356
}
351
357
else
352
358
{
353
- T [ ] array = enumerable . ToArray ( ) ;
354
359
return Create ( array ) ;
355
360
}
356
361
}
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan
364
369
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns>
365
370
public static Tensor < T > Create < T > ( IEnumerable < T > enumerable , scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool pinned = false )
366
371
{
372
+ T [ ] array = enumerable . ToArray ( ) ;
373
+
367
374
if ( pinned )
368
375
{
369
- T [ ] array = enumerable . ToArray ( ) ;
370
-
371
376
Tensor < T > tensor = CreateUninitialized < T > ( lengths , strides , pinned ) ;
372
377
array . CopyTo ( tensor . _values ) ;
373
378
374
379
return tensor ;
375
380
}
376
381
else
377
382
{
378
- T [ ] array = enumerable . ToArray ( ) ;
379
383
return Create ( array , lengths , strides ) ;
380
384
}
381
385
}
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y)
620
624
/// <param name="value">Value to update in the <paramref name="tensor"/>.</param>
621
625
public static ref readonly TensorSpan < T > FilteredUpdate < T > ( in this TensorSpan < T > tensor , scoped in ReadOnlyTensorSpan < bool > filter , T value )
622
626
{
623
- if ( filter . Lengths . Length != tensor . Lengths . Length )
624
- ThrowHelper . ThrowArgument_DimensionsNotSame ( nameof ( filter ) ) ;
625
-
626
- Span < T > srcSpan = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
627
- Span < bool > filterSpan = MemoryMarshal . CreateSpan ( ref filter . _reference , ( int ) tensor . _shape . LinearLength ) ;
628
-
629
- for ( int i = 0 ; i < filterSpan . Length ; i ++ )
630
- {
631
- if ( filterSpan [ i ] )
632
- {
633
- srcSpan [ i ] = value ;
634
- }
635
- }
636
-
627
+ TensorOperation . ValidateCompatibility ( filter , tensor ) ;
628
+ TensorOperation . Invoke < TensorOperation . FilteredUpdate < T > , bool , T , T > ( filter , value , tensor ) ;
637
629
return ref tensor ;
638
630
}
639
631
@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T>
646
638
/// <param name="values">Values to update in the <paramref name="tensor"/>.</param>
647
639
public static ref readonly TensorSpan < T > FilteredUpdate < T > ( in this TensorSpan < T > tensor , scoped in ReadOnlyTensorSpan < bool > filter , scoped in ReadOnlyTensorSpan < T > values )
648
640
{
649
- if ( filter . Lengths . Length != tensor . Lengths . Length )
650
- ThrowHelper . ThrowArgument_DimensionsNotSame ( nameof ( filter ) ) ;
651
- if ( values . Rank != 1 )
652
- ThrowHelper . ThrowArgument_1DTensorRequired ( nameof ( values ) ) ;
653
-
654
- Span < T > dstSpan = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
655
- Span < bool > filterSpan = MemoryMarshal . CreateSpan ( ref filter . _reference , ( int ) tensor . _shape . LinearLength ) ;
656
- Span < T > valuesSpan = MemoryMarshal . CreateSpan ( ref values . _reference , ( int ) values . _shape . LinearLength ) ;
657
-
658
- int index = 0 ;
659
- for ( int i = 0 ; i < filterSpan . Length ; i ++ )
660
- {
661
- if ( filterSpan [ i ] )
662
- {
663
- dstSpan [ i ] = valuesSpan [ index ++ ] ;
664
- }
665
- }
666
-
641
+ TensorOperation . ValidateCompatibility ( filter , values , tensor ) ;
642
+ TensorOperation . Invoke < TensorOperation . FilteredUpdate < T > , bool , T , T > ( filter , values , tensor ) ;
667
643
return ref tensor ;
668
644
}
669
645
#endregion
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
1409
1385
}
1410
1386
else
1411
1387
{
1388
+ if ( ! dimensions . IsEmpty && dimensions . Length != tensor . Lengths . Length )
1389
+ ThrowHelper . ThrowArgument_PermuteAxisOrder ( ) ;
1390
+
1412
1391
scoped Span < nint > newLengths = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < nint > lengthsRentedBuffer ) ;
1413
1392
scoped Span < nint > newStrides = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < nint > stridesRentedBuffer ) ;
1414
1393
scoped Span < int > newLinearOrder = TensorOperation . RentedBuffer . CreateUninitialized ( tensor . Rank , out TensorOperation . RentedBuffer < int > linearOrderRentedBuffer ) ;
@@ -1426,11 +1405,12 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
1426
1405
}
1427
1406
else
1428
1407
{
1429
- if ( dimensions . Length != tensor . Lengths . Length )
1430
- ThrowHelper . ThrowArgument_PermuteAxisOrder ( ) ;
1431
-
1432
1408
for ( int i = 0 ; i < dimensions . Length ; i ++ )
1433
1409
{
1410
+ if ( dimensions [ i ] >= tensor . Lengths . Length || dimensions [ i ] < 0 )
1411
+ {
1412
+ ThrowHelper . ThrowArgument_InvalidDimension ( ) ;
1413
+ }
1434
1414
newLengths [ i ] = tensor . Lengths [ dimensions [ i ] ] ;
1435
1415
newStrides [ i ] = tensor . Strides [ dimensions [ i ] ] ;
1436
1416
newLinearOrder [ i ] = tensor . _shape . LinearRankOrder [ dimensions [ i ] ] ;
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
1467
1447
1468
1448
nint [ ] newLengths = lengths . ToArray ( ) ;
1469
1449
// Calculate wildcard info.
1470
- if ( lengths . Contains ( - 1 ) )
1450
+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1451
+ if ( wildcardIndex >= 0 )
1471
1452
{
1472
1453
if ( lengths . Count ( - 1 ) > 1 )
1473
1454
ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1479,7 +1460,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
1479
1460
tempTotal /= lengths [ i ] ;
1480
1461
}
1481
1462
}
1482
- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1463
+ newLengths [ wildcardIndex ] = tempTotal ;
1483
1464
}
1484
1465
1485
1466
nint tempLinear = TensorPrimitives . Product ( newLengths ) ;
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
1538
1519
}
1539
1520
1540
1521
nint [ ] newLengths = lengths . ToArray ( ) ;
1541
- // Calculate wildcard info.
1542
- if ( lengths . Contains ( - 1 ) )
1522
+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1523
+ if ( wildcardIndex >= 0 )
1543
1524
{
1544
1525
if ( lengths . Count ( - 1 ) > 1 )
1545
1526
ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1551,7 +1532,7 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
1551
1532
tempTotal /= lengths [ i ] ;
1552
1533
}
1553
1534
}
1554
- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1535
+ newLengths [ wildcardIndex ] = tempTotal ;
1555
1536
1556
1537
}
1557
1538
@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
1615
1596
1616
1597
nint [ ] newLengths = lengths . ToArray ( ) ;
1617
1598
// Calculate wildcard info.
1618
- if ( lengths . Contains ( - 1 ) )
1599
+ int wildcardIndex = lengths . IndexOf ( - 1 ) ;
1600
+ if ( wildcardIndex >= 0 )
1619
1601
{
1620
1602
if ( lengths . Count ( - 1 ) > 1 )
1621
1603
ThrowHelper . ThrowArgument_OnlyOneWildcard ( ) ;
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
1627
1609
tempTotal /= lengths [ i ] ;
1628
1610
}
1629
1611
}
1630
- newLengths [ lengths . IndexOf ( - 1 ) ] = tempTotal ;
1612
+ newLengths [ wildcardIndex ] = tempTotal ;
1631
1613
1632
1614
}
1633
1615
@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
1701
1683
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
1702
1684
public static void ResizeTo < T > ( scoped in Tensor < T > tensor , in TensorSpan < T > destination )
1703
1685
{
1704
- ReadOnlySpan < T > span = MemoryMarshal . CreateSpan ( ref Unsafe . Add ( ref tensor . AsTensorSpan ( ) . _reference , tensor . _start ) , ( int ) tensor . _values . Length - tensor . _start ) ;
1705
- Span < T > ospan = MemoryMarshal . CreateSpan ( ref destination . _reference , ( int ) destination . _shape . LinearLength ) ;
1706
- if ( ospan . Length >= span . Length )
1707
- span . CopyTo ( ospan ) ;
1708
- else
1709
- span . Slice ( 0 , ospan . Length ) . CopyTo ( ospan ) ;
1686
+ ResizeTo ( tensor . AsReadOnlyTensorSpan ( ) , destination ) ;
1710
1687
}
1711
1688
1712
1689
/// <summary>
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest
1717
1694
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
1718
1695
public static void ResizeTo < T > ( scoped in TensorSpan < T > tensor , in TensorSpan < T > destination )
1719
1696
{
1720
- ReadOnlySpan < T > span = MemoryMarshal . CreateSpan ( ref tensor . _reference , ( int ) tensor . _shape . LinearLength ) ;
1721
- Span < T > ospan = MemoryMarshal . CreateSpan ( ref destination . _reference , ( int ) destination . _shape . LinearLength ) ;
1722
- if ( ospan . Length >= span . Length )
1723
- span . CopyTo ( ospan ) ;
1724
- else
1725
- span . Slice ( 0 , ospan . Length ) . CopyTo ( ospan ) ;
1697
+ ResizeTo ( tensor . AsReadOnlyTensorSpan ( ) , destination ) ;
1726
1698
}
1727
1699
1728
1700
/// <summary>
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
1890
1862
/// <param name="dimension">The axis to split on.</param>
1891
1863
public static Tensor < T > [ ] Split < T > ( scoped in ReadOnlyTensorSpan < T > tensor , int splitCount , nint dimension )
1892
1864
{
1865
+ if ( dimension < 0 || dimension >= tensor . Rank )
1866
+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
1893
1867
if ( tensor . Lengths [ ( int ) dimension ] % splitCount != 0 )
1894
1868
ThrowHelper . ThrowArgument_SplitNotSplitEvenly ( ) ;
1895
1869
@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
2221
2195
ThrowHelper . ThrowArgument_StackShapesNotSame ( ) ;
2222
2196
}
2223
2197
2224
- if ( dimension < 0 )
2225
- dimension = tensors [ 0 ] . Rank - dimension ;
2198
+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2199
+ // with our call to Unsqueeze.
2200
+ if ( dimension < 0 || dimension > tensors [ 0 ] . Rank )
2201
+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
2226
2202
2227
2203
Tensor < T > [ ] outputs = new Tensor < T > [ tensors . Length ] ;
2228
2204
for ( int i = 0 ; i < tensors . Length ; i ++ )
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
2259
2235
ThrowHelper . ThrowArgument_StackShapesNotSame ( ) ;
2260
2236
}
2261
2237
2262
- if ( dimension < 0 )
2263
- dimension = tensors [ 0 ] . Rank - dimension ;
2238
+ // We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2239
+ // with our call to Unsqueeze.
2240
+ if ( dimension < 0 || dimension > tensors [ 0 ] . Rank )
2241
+ ThrowHelper . ThrowArgument_AxisLargerThanRank ( ) ;
2264
2242
2265
2243
Tensor < T > [ ] outputs = new Tensor < T > [ tensors . Length ] ;
2266
2244
for ( int i = 0 ; i < tensors . Length ; i ++ )
0 commit comments