Skip to content

Commit 0c6ac9c

Browse files
Minor tensor fixes (#115125)
* minor tensor fixes * Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs Co-authored-by: Copilot <[email protected]> * Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs Co-authored-by: Copilot <[email protected]> * pr updates * fixing reshape * reverting co-pilot suggestions * added some comments to clarify dimension * add checked * fixes from pr comments --------- Co-authored-by: Copilot <[email protected]>
1 parent 458b119 commit 0c6ac9c

File tree

4 files changed

+126
-97
lines changed

4 files changed

+126
-97
lines changed

src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx

+2-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@
177177
<data name="ThrowArgument_InPlaceInvalidShape" xml:space="preserve">
178178
<value>In place operations require the same shape for both tensors</value>
179179
</data>
180-
<data name="ThrowArgument_InvalidAxis" xml:space="preserve">
181-
<value>Invalid axis provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
180+
<data name="ThrowArgument_InvalidDimension" xml:space="preserve">
181+
<value>Invalid dimension provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
182182
</data>
183183
<data name="ThrowArgument_InvalidConcatenateShape" xml:space="preserve">
184184
<value>The tensors must have the same shape, except in the dimension corresponding to axis.</value>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs

+62-84
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
133133
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();
134134

135135
if (dimension < -1 || dimension > tensors[0].Rank)
136-
ThrowHelper.ThrowArgument_InvalidAxis();
136+
ThrowHelper.ThrowArgument_InvalidDimension();
137137

138-
// Calculate total space needed.
139-
nint totalLength = 0;
140-
for (int i = 0; i < tensors.Length; i++)
141-
totalLength += tensors[i].FlattenedLength;
138+
Tensor<T> tensor;
142139

143-
nint sumOfAxis = 0;
144140
// If axis != -1, make sure all dimensions except the one to concatenate on match.
145141
if (dimension != -1)
146142
{
147-
sumOfAxis = tensors[0].Lengths[dimension];
143+
nint sumOfAxis = tensors[0].Lengths[dimension];
148144
for (int i = 1; i < tensors.Length; i++)
149145
{
150146
if (tensors[0].Rank != tensors[i].Rank)
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
157153
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
158154
}
159155
}
160-
sumOfAxis += tensors[i].Lengths[dimension];
156+
checked
157+
{
158+
sumOfAxis += tensors[i].Lengths[dimension];
159+
}
161160
}
162-
}
163161

164-
Tensor<T> tensor;
165-
if (dimension == -1)
166-
{
167-
tensor = Tensor.Create<T>([totalLength]);
168-
}
169-
else
170-
{
171162
nint[] lengths = new nint[tensors[0].Rank];
172163
tensors[0].Lengths.CopyTo(lengths);
173164
lengths[dimension] = sumOfAxis;
174165
tensor = Tensor.Create<T>(lengths);
175166
}
167+
else
168+
{
169+
// Calculate total space needed.
170+
nint totalLength = 0;
171+
for (int i = 0; i < tensors.Length; i++)
172+
{
173+
checked
174+
{
175+
totalLength += tensors[i].FlattenedLength;
176+
}
177+
}
178+
179+
tensor = Tensor.Create<T>([totalLength]);
180+
}
176181

177182
ConcatenateOnDimension(dimension, tensors, tensor);
178183
return tensor;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
201206
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();
202207

203208
if (dimension < -1 || dimension > tensors[0].Rank)
204-
ThrowHelper.ThrowArgument_InvalidAxis();
209+
ThrowHelper.ThrowArgument_InvalidDimension();
205210

206211
// Calculate total space needed.
207212
nint totalLength = 0;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
212217
if (dimension != -1)
213218
{
214219
nint sumOfAxis = tensors[0].Lengths[dimension];
220+
int rank = tensors[0].Rank;
215221
for (int i = 1; i < tensors.Length; i++)
216222
{
217-
if (tensors[0].Rank != tensors[i].Rank)
223+
if (rank != tensors[i].Rank)
218224
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
219-
for (int j = 0; j < tensors[0].Rank; j++)
225+
for (int j = 0; j < rank; j++)
220226
{
221227
if (j != dimension)
222228
{
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
228234
}
229235

230236
// Make sure the destination tensor has the correct shape.
231-
nint[] lengths = new nint[tensors[0].Rank];
237+
nint[] lengths = new nint[rank];
232238
tensors[0].Lengths.CopyTo(lengths);
233239
lengths[dimension] = sumOfAxis;
234240

@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint
339345
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns>
340346
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, bool pinned = false)
341347
{
348+
T[] array = enumerable.ToArray();
349+
342350
if (pinned)
343351
{
344-
T[] array = enumerable.ToArray();
345-
346352
Tensor<T> tensor = CreateUninitialized<T>([array.Length], pinned);
347353
array.CopyTo(tensor._values);
348354

349355
return tensor;
350356
}
351357
else
352358
{
353-
T[] array = enumerable.ToArray();
354359
return Create(array);
355360
}
356361
}
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan
364369
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns>
365370
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned = false)
366371
{
372+
T[] array = enumerable.ToArray();
373+
367374
if (pinned)
368375
{
369-
T[] array = enumerable.ToArray();
370-
371376
Tensor<T> tensor = CreateUninitialized<T>(lengths, strides, pinned);
372377
array.CopyTo(tensor._values);
373378

374379
return tensor;
375380
}
376381
else
377382
{
378-
T[] array = enumerable.ToArray();
379383
return Create(array, lengths, strides);
380384
}
381385
}
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y)
620624
/// <param name="value">Value to update in the <paramref name="tensor"/>.</param>
621625
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, T value)
622626
{
623-
if (filter.Lengths.Length != tensor.Lengths.Length)
624-
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
625-
626-
Span<T> srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
627-
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
628-
629-
for (int i = 0; i < filterSpan.Length; i++)
630-
{
631-
if (filterSpan[i])
632-
{
633-
srcSpan[i] = value;
634-
}
635-
}
636-
627+
TensorOperation.ValidateCompatibility(filter, tensor);
628+
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, value, tensor);
637629
return ref tensor;
638630
}
639631

@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T>
646638
/// <param name="values">Values to update in the <paramref name="tensor"/>.</param>
647639
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, scoped in ReadOnlyTensorSpan<T> values)
648640
{
649-
if (filter.Lengths.Length != tensor.Lengths.Length)
650-
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
651-
if (values.Rank != 1)
652-
ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values));
653-
654-
Span<T> dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
655-
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
656-
Span<T> valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength);
657-
658-
int index = 0;
659-
for (int i = 0; i < filterSpan.Length; i++)
660-
{
661-
if (filterSpan[i])
662-
{
663-
dstSpan[i] = valuesSpan[index++];
664-
}
665-
}
666-
641+
TensorOperation.ValidateCompatibility(filter, values, tensor);
642+
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, values, tensor);
667643
return ref tensor;
668644
}
669645
#endregion
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
14091385
}
14101386
else
14111387
{
1388+
if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length)
1389+
ThrowHelper.ThrowArgument_PermuteAxisOrder();
1390+
14121391
scoped Span<nint> newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> lengthsRentedBuffer);
14131392
scoped Span<nint> newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> stridesRentedBuffer);
14141393
scoped Span<int> newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<int> linearOrderRentedBuffer);
@@ -1426,11 +1405,12 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
14261405
}
14271406
else
14281407
{
1429-
if (dimensions.Length != tensor.Lengths.Length)
1430-
ThrowHelper.ThrowArgument_PermuteAxisOrder();
1431-
14321408
for (int i = 0; i < dimensions.Length; i++)
14331409
{
1410+
if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0)
1411+
{
1412+
ThrowHelper.ThrowArgument_InvalidDimension();
1413+
}
14341414
newLengths[i] = tensor.Lengths[dimensions[i]];
14351415
newStrides[i] = tensor.Strides[dimensions[i]];
14361416
newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]];
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
14671447

14681448
nint[] newLengths = lengths.ToArray();
14691449
// Calculate wildcard info.
1470-
if (lengths.Contains(-1))
1450+
int wildcardIndex = lengths.IndexOf(-1);
1451+
if (wildcardIndex >= 0)
14711452
{
14721453
if (lengths.Count(-1) > 1)
14731454
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1479,7 +1460,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
14791460
tempTotal /= lengths[i];
14801461
}
14811462
}
1482-
newLengths[lengths.IndexOf(-1)] = tempTotal;
1463+
newLengths[wildcardIndex] = tempTotal;
14831464
}
14841465

14851466
nint tempLinear = TensorPrimitives.Product(newLengths);
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
15381519
}
15391520

15401521
nint[] newLengths = lengths.ToArray();
1541-
// Calculate wildcard info.
1542-
if (lengths.Contains(-1))
1522+
int wildcardIndex = lengths.IndexOf(-1);
1523+
if (wildcardIndex >= 0)
15431524
{
15441525
if (lengths.Count(-1) > 1)
15451526
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1551,7 +1532,7 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
15511532
tempTotal /= lengths[i];
15521533
}
15531534
}
1554-
newLengths[lengths.IndexOf(-1)] = tempTotal;
1535+
newLengths[wildcardIndex] = tempTotal;
15551536

15561537
}
15571538

@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
16151596

16161597
nint[] newLengths = lengths.ToArray();
16171598
// Calculate wildcard info.
1618-
if (lengths.Contains(-1))
1599+
int wildcardIndex = lengths.IndexOf(-1);
1600+
if (wildcardIndex >= 0)
16191601
{
16201602
if (lengths.Count(-1) > 1)
16211603
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1627,7 +1609,7 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
16271609
tempTotal /= lengths[i];
16281610
}
16291611
}
1630-
newLengths[lengths.IndexOf(-1)] = tempTotal;
1612+
newLengths[wildcardIndex] = tempTotal;
16311613

16321614
}
16331615

@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
17011683
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
17021684
public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> destination)
17031685
{
1704-
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
1705-
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
1706-
if (ospan.Length >= span.Length)
1707-
span.CopyTo(ospan);
1708-
else
1709-
span.Slice(0, ospan.Length).CopyTo(ospan);
1686+
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
17101687
}
17111688

17121689
/// <summary>
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest
17171694
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
17181695
public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T> destination)
17191696
{
1720-
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
1721-
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
1722-
if (ospan.Length >= span.Length)
1723-
span.CopyTo(ospan);
1724-
else
1725-
span.Slice(0, ospan.Length).CopyTo(ospan);
1697+
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
17261698
}
17271699

17281700
/// <summary>
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
18901862
/// <param name="dimension">The axis to split on.</param>
18911863
public static Tensor<T>[] Split<T>(scoped in ReadOnlyTensorSpan<T> tensor, int splitCount, nint dimension)
18921864
{
1865+
if (dimension < 0 || dimension >= tensor.Rank)
1866+
ThrowHelper.ThrowArgument_AxisLargerThanRank();
18931867
if (tensor.Lengths[(int)dimension] % splitCount != 0)
18941868
ThrowHelper.ThrowArgument_SplitNotSplitEvenly();
18951869

@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
22212195
ThrowHelper.ThrowArgument_StackShapesNotSame();
22222196
}
22232197

2224-
if (dimension < 0)
2225-
dimension = tensors[0].Rank - dimension;
2198+
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2199+
// with our call to Unsqueeze.
2200+
if (dimension < 0 || dimension > tensors[0].Rank)
2201+
ThrowHelper.ThrowArgument_AxisLargerThanRank();
22262202

22272203
Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
22282204
for (int i = 0; i < tensors.Length; i++)
@@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
22592235
ThrowHelper.ThrowArgument_StackShapesNotSame();
22602236
}
22612237

2262-
if (dimension < 0)
2263-
dimension = tensors[0].Rank - dimension;
2238+
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
2239+
// with our call to Unsqueeze.
2240+
if (dimension < 0 || dimension > tensors[0].Rank)
2241+
ThrowHelper.ThrowArgument_AxisLargerThanRank();
22642242

22652243
Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
22662244
for (int i = 0; i < tensors.Length; i++)

0 commit comments

Comments
 (0)