Skip to content

Commit 418bd16

Browse files
Added a number of vision classes and functionality.
Duplicated more of the Tensor methods into 'torch'
1 parent 732e53d commit 418bd16

30 files changed

+745
-83
lines changed

src/FSharp.Examples/FSharp.Examples.fsproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
<ItemGroup>
2323
<PackageReference Include="SharpZipLib" Version="1.3.1" />
24-
<PackageReference Include="System.Memory" Version="4.5.1" />
24+
<PackageReference Include="System.Memory" Version="4.5.3" />
2525
</ItemGroup>
2626

2727
<ItemGroup>

src/TorchSharp/NN/Vision.cs

+45-2
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,52 @@ public static partial class functional
7575
/// <returns></returns>
7676
static public Tensor pad(Tensor input, long[] pad, PaddingModes mode = PaddingModes.Constant, double value = 0)
7777
{
78+
//
79+
// The Pytorch documentation does not cover what is actually happening in the native code, as far as
80+
// the ordering of padding elements goes. This code converts from the documented order, to the actual.
81+
// See: https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.functional.pad
82+
//
83+
long[] correctedPad;
84+
85+
switch (pad.Length) {
86+
case 1:
87+
correctedPad = new long[] { pad[0], pad[0], pad[0], pad[0] };
88+
break;
89+
case 2:
90+
correctedPad = new long[] { pad[0], pad[0], pad[1], pad[1] };
91+
break;
92+
case 4:
93+
correctedPad = new long[] { pad[0], pad[2], pad[1], pad[3] };
94+
break;
95+
default:
96+
correctedPad = pad;
97+
break;
98+
}
99+
100+
unsafe {
101+
fixed (long* psize = correctedPad) {
102+
var res = THSNN_pad(input.Handle, (IntPtr)psize, 4, (byte)mode, value);
103+
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
104+
return new Tensor(res);
105+
}
106+
}
107+
}
108+
109+
/// <summary>
110+
/// Pads tensor.
111+
/// </summary>
112+
/// <param name="input">N-dimensional tensor</param>
113+
/// <param name="pad">A single padding size, used for all edges.</param>
114+
/// <param name="mode">'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'</param>
115+
/// <param name="value">Fill value for 'constant' padding. Default: 0</param>
116+
/// <returns></returns>
117+
static public Tensor pad(Tensor input, long pad, PaddingModes mode = PaddingModes.Constant, double value = 0)
118+
{
119+
long[] correctedPad = new long[] { pad, pad, pad, pad };
120+
78121
unsafe {
79-
fixed (long* psize = pad) {
80-
var res = THSNN_pad(input.Handle, (IntPtr)psize, pad.Length, (byte)mode, value);
122+
fixed (long* psize = correctedPad) {
123+
var res = THSNN_pad(input.Handle, (IntPtr)psize, 4, (byte)mode, value);
81124
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
82125
return new Tensor(res);
83126
}

src/TorchSharp/Tensor/Tensor.Factories.cs

+21-37
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ static public Tensor zeros(long dim0, long dim1, long dim2, long dim3, torch.Sca
163163
return zeros(new long[] { dim0, dim1, dim2, dim3 }, dtype, device, requiresGrad);
164164
}
165165

166+
/// <summary>
167+
/// Returns a tensor filled with the scalar value 0, with the same size as input.
168+
/// </summary>
169+
public static Tensor zeros_like(Tensor input, ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false) => input.zeros_like(dtype, device, requiresGrad);
170+
166171

167172
// ones()
168173

@@ -226,6 +231,11 @@ static public Tensor ones(long dim0, long dim1, long dim2, long dim3, torch.Scal
226231
return ones(new long[] { dim0, dim1, dim2, dim3 }, dtype, device, requiresGrad);
227232
}
228233

234+
/// <summary>
235+
/// Returns a tensor filled with the scalar value 1, with the same size as input.
236+
/// </summary>
237+
public static Tensor ones_like(Tensor input, ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false) => input.ones_like(dtype, device, requiresGrad);
238+
229239

230240
// empty()
231241

@@ -289,6 +299,12 @@ static public Tensor empty(long dim0, long dim1, long dim2, long dim3, torch.Sca
289299
return empty(new long[] { dim0, dim1, dim2, dim3 }, dtype, device, requiresGrad);
290300
}
291301

302+
/// <summary>
303+
/// Returns a tensor filled with uninitialized data, with the same size as input.
304+
/// </summary>
305+
public static Tensor empty_like(Tensor input, ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false) => input.empty_like(dtype, device, requiresGrad);
306+
307+
292308
[DllImport("LibTorchSharp")]
293309
extern static IntPtr THSTensor_empty_strided(IntPtr psizes, int sz_length, IntPtr pstrides, int str_length, sbyte scalarType, int deviceType, int deviceIndex, bool requiresGrad);
294310

@@ -381,7 +397,10 @@ static public Tensor full(long dim0, long dim1, long dim2, long dim3, Scalar val
381397
return full(new long[] { dim0, dim1, dim2, dim3 }, value, dtype, device, requiresGrad);
382398
}
383399

384-
400+
/// <summary>
401+
/// Returns a tensor with the same size as input filled with 'value'.
402+
/// </summary>
403+
static public Tensor full_like(Tensor input, Scalar value, ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false) => input.full_like(value, dtype, device, requiresGrad);
385404

386405

387406
[DllImport("LibTorchSharp")]
@@ -737,7 +756,7 @@ public static Tensor tensor(long scalar, torch.ScalarType? dtype = null, torch.D
737756
public static Tensor tensor(float scalar, torch.Device device = null, bool requiresGrad = false)
738757
{
739758
device = torch.InitializeDevice(device);
740-
var handle = THSTensor_newFloat16Scalar(scalar, (int)device.type, device.index, requiresGrad);
759+
var handle = THSTensor_newFloat32Scalar(scalar, (int)device.type, device.index, requiresGrad);
741760
if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
742761
return new Tensor(handle);
743762
}
@@ -776,23 +795,6 @@ public static Tensor tensor((float Real, float Imaginary) scalar, torch.ScalarTy
776795
return tensor;
777796
}
778797

779-
/// <summary>
780-
/// Create a scalar tensor from a single value
781-
/// </summary>
782-
public static Tensor tensor(float real, float imaginary = 0.0f, torch.ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false)
783-
{
784-
device = torch.InitializeDevice(device);
785-
var handle = THSTensor_newComplexFloat32Scalar(real, imaginary, (int)device.type, device.index, requiresGrad);
786-
if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
787-
var tensor = new Tensor(handle);
788-
if (device is not null) {
789-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
790-
} else if (dtype.HasValue) {
791-
tensor.to_type(dtype.Value);
792-
}
793-
return tensor;
794-
}
795-
796798
/// <summary>
797799
/// Create a scalar tensor from a single value
798800
/// </summary>
@@ -810,24 +812,6 @@ public static Tensor tensor(System.Numerics.Complex scalar, torch.ScalarType? dt
810812
return tensor;
811813
}
812814

813-
/// <summary>
814-
/// Create a scalar tensor from a single value
815-
/// </summary>
816-
public static Tensor tensor(double real, double imaginary = 0.0f, torch.ScalarType? dtype = null, torch.Device device = null, bool requiresGrad = false)
817-
{
818-
device = torch.InitializeDevice(device);
819-
var handle = THSTensor_newComplexFloat64Scalar(real, imaginary, (int)device.type, device.index, requiresGrad);
820-
if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
821-
var tensor = new Tensor(handle);
822-
if (device is not null) {
823-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
824-
} else if (dtype.HasValue) {
825-
tensor.to_type(dtype.Value);
826-
}
827-
return tensor;
828-
}
829-
830-
831815
/// <summary>
832816
/// Create a tensor from an array of values, shaping it based on the shape passed in.
833817
/// </summary>

src/TorchSharp/Tensor/Tensor.Math.cs

+53
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,11 @@ public Tensor xlogy_(Scalar y)
17911791
return left.sub(right);
17921792
}
17931793

1794+
public static Tensor operator -(Scalar left, Tensor right)
1795+
{
1796+
return right.negative().add(left);
1797+
}
1798+
17941799
public static Tensor operator /(Tensor left, Tensor right)
17951800
{
17961801
return left.div(right);
@@ -1801,6 +1806,11 @@ public Tensor xlogy_(Scalar y)
18011806
return left.div(right);
18021807
}
18031808

1809+
public static Tensor operator /(Scalar left, Tensor right)
1810+
{
1811+
return right.reciprocal().mul(left);
1812+
}
1813+
18041814
public static Tensor operator %(Tensor left, Tensor right)
18051815
{
18061816
return left.remainder(right);
@@ -2083,6 +2093,33 @@ public Tensor xlogy_(Scalar y)
20832093
/// <returns></returns>
20842094
public static Tensor divide_(Tensor left, Scalar right) => left.div_(right);
20852095

2096+
[DllImport("LibTorchSharp")]
2097+
extern static IntPtr THSTensor_einsum([MarshalAs(UnmanagedType.LPStr)] string location, IntPtr tensors, int len);
2098+
2099+
/// <summary>
2100+
/// Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.
2101+
/// </summary>
2102+
/// <param name="equation">The subscripts for the Einstein summation.</param>
2103+
/// <param name="tensors">The operands to compute the Einstein sum of.</param>
2104+
/// <remarks>
2105+
/// Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them in a short-hand format based on the
2106+
/// Einstein summation convention, given by equation. The details of this format are described below, but the general idea is to label every dimension
2107+
/// of the input operands with some subscript and define which subscripts are part of the output. The output is then computed by summing the product
2108+
/// of the elements of the operands along the dimensions whose subscripts are not part of the output. For example, matrix multiplication can be computed
2109+
/// using einsum as torch.einsum("ij,jk-&gt;ik", A, B). Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
2110+
/// </remarks>
2111+
/// <returns></returns>
2112+
public static Tensor einsum(string equation, params Tensor[] tensors)
2113+
{
2114+
using (var parray = new PinnedArray<IntPtr>()) {
2115+
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
2116+
2117+
var res = THSTensor_einsum(equation, tensorsRef, parray.Array.Length);
2118+
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
2119+
return new Tensor(res);
2120+
}
2121+
}
2122+
20862123
/// <summary>
20872124
/// Returns a new tensor with the exponential of the elements of the input tensor input.
20882125
/// </summary>
@@ -2278,6 +2315,22 @@ public Tensor xlogy_(Scalar y)
22782315
/// <returns></returns>
22792316
public static Tensor logit(Tensor input, double? eps = null) => input.logit(eps);
22802317

2318+
public static Tensor max(Tensor input) => input.max();
2319+
2320+
static public Tensor max(Tensor input, Tensor other) => input.max(other);
2321+
2322+
static public (Tensor values, Tensor indexes) max(Tensor input, long dimension, bool keepDim = false) => input.max(dimension, keepDim);
2323+
2324+
public static Tensor mean(Tensor input) => input.mean();
2325+
2326+
public static Tensor mean(Tensor input, long[] dimensions, bool keepDimension = false, ScalarType? type = null) => input.mean(dimensions, keepDimension, type);
2327+
2328+
public static Tensor min(Tensor input) => input.min();
2329+
2330+
static public Tensor min(Tensor input, Tensor other) => input.min(other);
2331+
2332+
static public (Tensor values, Tensor indexes) min(Tensor input, long dimension, bool keepDim = false) => input.min(dimension, keepDim);
2333+
22812334
/// <summary>
22822335
/// Divides each element of the input by the corresponding element of other.
22832336
/// </summary>

src/TorchSharp/Tensor/Tensor.cs

+9-18
Original file line numberDiff line numberDiff line change
@@ -3916,7 +3916,7 @@ public Tensor arange_out(Scalar start, Scalar stop, Scalar step)
39163916
/// Returns a view of the original tensor with its dimensions permuted.
39173917
/// </summary>
39183918
/// <param name="permutation">The desired ordering of dimensions</param>
3919-
public Tensor permute(long[] permutation)
3919+
public Tensor permute(params long[] permutation)
39203920
{
39213921
unsafe {
39223922
fixed (long* pPermutation = permutation) {
@@ -4318,20 +4318,6 @@ public Tensor where(Tensor condition, Tensor other)
43184318
return new Tensor(res);
43194319
}
43204320

4321-
[DllImport("LibTorchSharp")]
4322-
extern static IntPtr THSTensor_einsum([MarshalAs(UnmanagedType.LPStr)] string location, IntPtr tensors, int len);
4323-
4324-
public static Tensor einsum(string equation, params Tensor[] tensors)
4325-
{
4326-
using (var parray = new PinnedArray<IntPtr>()) {
4327-
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
4328-
4329-
var res = THSTensor_einsum(equation, tensorsRef, parray.Array.Length);
4330-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
4331-
return new Tensor(res);
4332-
}
4333-
}
4334-
43354321

43364322
// Operators overloading
43374323

@@ -4759,6 +4745,11 @@ public static implicit operator TensorIndex(long value)
47594745
{
47604746
return TensorIndex.Single(value);
47614747
}
4748+
4749+
public static implicit operator TensorIndex(System.Range value)
4750+
{
4751+
return TensorIndex.Slice(value.Start.Value, value.End.Value);
4752+
}
47624753
}
47634754

47644755
/// <summary>
@@ -4784,7 +4775,7 @@ public enum ScalarType : sbyte
47844775
BFloat16 = 15
47854776
}
47864777

4787-
static bool is_integral(ScalarType type)
4778+
public static bool is_integral(ScalarType type)
47884779
{
47894780
switch (type) {
47904781
case ScalarType.Byte:
@@ -4799,7 +4790,7 @@ static bool is_integral(ScalarType type)
47994790
}
48004791
}
48014792

4802-
static bool is_floating_point(ScalarType type)
4793+
public static bool is_floating_point(ScalarType type)
48034794
{
48044795
switch (type) {
48054796
case ScalarType.BFloat16:
@@ -4812,7 +4803,7 @@ static bool is_floating_point(ScalarType type)
48124803
}
48134804
}
48144805

4815-
static bool is_complex(ScalarType type)
4806+
public static bool is_complex(ScalarType type)
48164807
{
48174808
switch (type) {
48184809
case ScalarType.ComplexFloat32:

src/TorchSharp/Tensor/Tensor.torch.cs

+29-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ public static Tensor stack(IList<Tensor> tensors, long dimension)
5656
}
5757
}
5858

59+
/// <summary>
60+
/// Returns a view of the original tensor with its dimensions permuted.
61+
/// </summary>
62+
/// <param name="input">The input tensor.</param>
63+
/// <param name="permutation">The desired ordering of dimensions</param>
64+
static public Tensor permute(Tensor input, params long[] permutation) => input.permute(permutation);
65+
5966
[DllImport("LibTorchSharp")]
6067
extern static IntPtr THSTensor_hstack(IntPtr tensor, int len);
6168

@@ -153,5 +160,26 @@ public static Tensor dstack(IList<Tensor> tensors)
153160
return new Tensor(res);
154161
}
155162
}
163+
164+
static public Tensor clamp(Tensor input, Scalar min, Scalar max) => input.clamp(min, max);
165+
166+
static public Tensor clamp_(Tensor input, Scalar min, Scalar max) => input.clamp_(min, max);
167+
168+
static public Tensor clamp_max(Tensor input, Scalar max) => input.clamp_max(max);
169+
170+
static public Tensor clamp_max_(Tensor input, Scalar max) => input.clamp_max(max);
171+
172+
static public Tensor clamp_min(Tensor input, Scalar min) => input.clamp_min(min);
173+
174+
static public Tensor clamp_min_(Tensor input, Scalar min) => input.clamp_min(min);
175+
176+
/// <summary>
177+
/// Return a tensor of elements selected from either x or y, depending on condition.
178+
/// </summary>
179+
/// <param name="condition">When true, yield x, otherwise yield y.</param>
180+
/// <param name="x">Values selected at indices where condition is true</param>
181+
/// <param name="y">Values selected at indices where condition is false</param>
182+
/// <returns></returns>
183+
static public Tensor where(Tensor condition, Tensor x, Tensor y) => x.where(condition, y);
184+
}
156185
}
157-
}

src/TorchSharp/Tensor/TensorExtensionMethods.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,7 @@ public static Tensor crop(this Tensor image, int top, int left, int height, int
204204

205205
var slice = image.index(TensorIndex.Ellipsis, TensorIndex.Slice(Math.Max(top, 0), bottom), TensorIndex.Slice(Math.Max(left, 0), right));
206206

207-
// Note: according to the documentation, it should be LTRB, but that generates the wrong result. Here, we use LRTB.
208-
var padding_ltrb = new long[] { Math.Max(-left, 0), Math.Max(right - w, 0), Math.Max(-top, 0), Math.Max(bottom - h, 0) };
207+
var padding_ltrb = new long[] { Math.Max(-left, 0), Math.Max(-top, 0), Math.Max(right - w, 0), Math.Max(bottom - h, 0) };
209208

210209
return TorchSharp.torch.nn.functional.pad(slice, padding_ltrb);
211210
}

src/TorchSharp/TorchSharp.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
</ItemGroup>
2424

2525
<ItemGroup>
26-
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
26+
<PackageReference Include="System.Memory" Version="4.5.3" />
2727
</ItemGroup>
2828

2929
<ItemGroup>

0 commit comments

Comments
 (0)