Skip to content

Commit d93f3d7

Browse files
Adding more torch.cuda functionality.
Fixing the squeeze() issue.
1 parent ba8ae31 commit d93f3d7

File tree

9 files changed

+113
-7
lines changed

9 files changed

+113
-7
lines changed

src/Native/LibTorchSharp/THSTensor.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,11 @@ void THSTensor_dsplit_with_sizes(
12301230
)
12311231
}
12321232

1233+
Tensor THSTensor_squeeze_no_dim(Tensor tensor)
1234+
{
1235+
CATCH_TENSOR(tensor->squeeze());
1236+
}
1237+
12331238
Tensor THSTensor_squeeze(Tensor tensor, int64_t dim)
12341239
{
12351240
CATCH_TENSOR(tensor->squeeze(dim));

src/Native/LibTorchSharp/THSTensor.h

+1
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,7 @@ EXPORT_API(void) THSTensor_split_with_size(const Tensor tensor, Tensor* (*alloca
10451045
EXPORT_API(void) THSTensor_split_with_sizes(const Tensor tensor, Tensor* (*allocator)(size_t length), const int64_t* sizes, const int length, const int64_t dim);
10461046

10471047
EXPORT_API(Tensor) THSTensor_squeeze(Tensor tensor, int64_t dim);
1048+
EXPORT_API(Tensor) THSTensor_squeeze_no_dim(Tensor tensor);
10481049

10491050
EXPORT_API(Tensor) THSTensor_stack(const Tensor* tensor, const int length, const int64_t dim);
10501051

src/Native/LibTorchSharp/THSTorch.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ Generator THSGenerator_manual_seed(const int64_t seed)
1515
return THSGenerator_default_generator();
1616
}
1717

18+
void THSCuda_manual_seed(const int64_t seed)
19+
{
20+
CATCH(torch::cuda::manual_seed(seed);)
21+
}
22+
23+
void THSCuda_manual_seed_all(const int64_t seed)
24+
{
25+
CATCH(torch::cuda::manual_seed_all(seed);)
26+
}
27+
28+
1829
void THSGenerator_gen_manual_seed(const Generator generator, const int64_t seed)
1930
{
2031
generator->set_current_seed(seed);
@@ -69,6 +80,12 @@ int THSTorchCuda_device_count()
6980
return (int)torch::cuda::device_count();
7081
}
7182

83+
EXPORT_API(void) THSTorchCuda_synchronize(const int64_t device_index)
84+
{
85+
CATCH(torch::cuda::synchronize(device_index);)
86+
}
87+
88+
7289
const char * THSTorch_get_and_reset_last_err()
7390
{
7491
char *tmp = torch_last_err;

src/Native/LibTorchSharp/THSTorch.h

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
// Sets manually the seed.
1111
EXPORT_API(void) THSTorch_manual_seed(const int64_t seed);
12+
EXPORT_API(void) THSCuda_manual_seed(const int64_t seed);
13+
EXPORT_API(void) THSCuda_manual_seed_all(const int64_t seed);
1214

1315
EXPORT_API(Generator) THSGenerator_manual_seed(const int64_t seed);
1416
EXPORT_API(void) THSGenerator_gen_manual_seed(const Generator gen, const int64_t seed);
@@ -24,6 +26,7 @@ EXPORT_API(void) THSGenerator_dispose(const Generator generator);
2426
EXPORT_API(int) THSTorchCuda_is_available();
2527
EXPORT_API(int) THSTorchCuda_cudnn_is_available();
2628
EXPORT_API(int) THSTorchCuda_device_count();
29+
EXPORT_API(void) THSTorchCuda_synchronize(const int64_t device);
2730

2831
// Returns the latest error. This is thread-local.
2932
EXPORT_API(const char *) THSTorch_get_and_reset_last_err();

src/TorchSharp/Tensor/Tensor.cs

+9-6
Original file line numberDiff line numberDiff line change
@@ -1243,14 +1243,17 @@ public Tensor reshape(params long[] shape)
12431243
[DllImport("LibTorchSharp")]
12441244
static extern IntPtr THSTensor_squeeze(IntPtr tensor, long dimension);
12451245

1246+
[DllImport("LibTorchSharp")]
1247+
static extern IntPtr THSTensor_squeeze_no_dim(IntPtr tensor);
1248+
12461249
/// <summary>
12471250
/// Returns a tensor with all the dimensions of input of size 1 removed. When dim is given, a squeeze operation is done only in the given dimension.
12481251
/// </summary>
12491252
/// <param name="dim">If given, the input will be squeezed only in this dimension</param>
12501253
/// <returns></returns>
1251-
public Tensor squeeze(long dim)
1254+
public Tensor squeeze(long? dim = null)
12521255
{
1253-
var res = THSTensor_squeeze(handle, dim);
1256+
var res = dim.HasValue ? THSTensor_squeeze(handle, dim.Value) : THSTensor_squeeze_no_dim(handle);
12541257
if (res == IntPtr.Zero)
12551258
torch.CheckForErrors();
12561259
return new Tensor(res);
@@ -4239,9 +4242,9 @@ public Tensor slice(long dimension, long start, long finish, long step)
42394242
/// Returns a new tensor with a dimension of size one inserted at the specified position.
42404243
/// The returned tensor shares the same underlying data with this tensor.
42414244
/// </summary>
4242-
public Tensor unsqueeze(long dimension)
4245+
public Tensor unsqueeze(long dim)
42434246
{
4244-
var res = THSTensor_unsqueeze(handle, dimension);
4247+
var res = THSTensor_unsqueeze(handle, dim);
42454248
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
42464249
return new Tensor(res);
42474250
}
@@ -4253,9 +4256,9 @@ public Tensor unsqueeze(long dimension)
42534256
/// Returns a new tensor with a dimension of size one inserted at the specified position.
42544257
/// The returned tensor shares the same underlying data with this tensor.
42554258
/// </summary>
4256-
public Tensor unsqueeze_(long dimension)
4259+
public Tensor unsqueeze_(long dim)
42574260
{
4258-
var res = THSTensor_unsqueeze_(handle, dimension);
4261+
var res = THSTensor_unsqueeze_(handle, dim);
42594262
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
42604263
return new Tensor(res);
42614264
}

src/TorchSharp/Tensor/Tensor.torch.cs

+21
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,27 @@ public static Tensor cat(IList<Tensor> tensors, long dimension)
3737
}
3838
}
3939

40+
/// <summary>
41+
/// Returns a tensor with all the dimensions of input of size 1 removed. When dim is given, a squeeze operation is done only in the given dimension.
42+
/// </summary>
43+
/// <param name="input"></param>
44+
/// <param name="dim">If given, the input will be squeezed only in this dimension</param>
45+
/// <returns></returns>
46+
public static Tensor squeeze(Tensor input, long? dim = null) => input.squeeze(dim);
47+
48+
49+
/// <summary>
50+
/// Returns a new tensor with a dimension of size one inserted at the specified position.
51+
/// The returned tensor shares the same underlying data with this tensor.
52+
/// </summary>
53+
public static Tensor unsqueeze(Tensor input, long dim) => input.unsqueeze(dim);
54+
55+
/// <summary>
56+
/// Returns a new tensor with a dimension of size one inserted at the specified position.
57+
/// The returned tensor shares the same underlying data with this tensor.
58+
/// </summary>
59+
public static Tensor unsqueeze_(Tensor input, long dim) => input.unsqueeze_(dim);
60+
4061
[DllImport("LibTorchSharp")]
4162
extern static IntPtr THSTensor_stack(IntPtr tensor, int len, long dim);
4263

src/TorchSharp/Torch.cs

+49
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ internal static bool CallTorchCudaIsAvailable()
287287
return THSTorchCuda_is_available();
288288
}
289289

290+
/// <summary>
291+
/// Returns a bool indicating if CUDA is currently available.
292+
/// </summary>
293+
/// <returns></returns>
290294
public static bool is_available()
291295
{
292296
TryInitializeDeviceType(DeviceType.CUDA);
@@ -305,12 +309,57 @@ public static bool is_cudnn_available()
305309
[DllImport("LibTorchSharp")]
306310
private static extern int THSTorchCuda_device_count();
307311

312+
/// <summary>
313+
/// Returns the number of GPUs available.
314+
/// </summary>
315+
/// <returns></returns>
308316
public static int device_count()
309317
{
310318
TryInitializeDeviceType(DeviceType.CUDA);
311319
return THSTorchCuda_device_count();
312320
}
313321

322+
[DllImport("LibTorchSharp")]
323+
private static extern void THSCuda_manual_seed(long seed);
324+
325+
/// <summary>
326+
/// Sets the seed for generating random numbers for the current GPU.
327+
/// It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.
328+
/// </summary>
329+
/// <param name="seed">The desired seed.</param>
330+
public static void manual_seed(long seed)
331+
{
332+
TryInitializeDeviceType(DeviceType.CUDA);
333+
THSCuda_manual_seed(seed);
334+
}
335+
336+
[DllImport("LibTorchSharp")]
337+
private static extern void THSCuda_manual_seed_all(long seed);
338+
339+
/// <summary>
340+
/// Sets the seed for generating random numbers on all GPUs.
341+
/// It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.
342+
/// </summary>
343+
/// <param name="seed"></param>
344+
public static void manual_seed_all(long seed)
345+
{
346+
TryInitializeDeviceType(DeviceType.CUDA);
347+
THSCuda_manual_seed_all(seed);
348+
}
349+
350+
[DllImport("LibTorchSharp")]
351+
private static extern void THSCuda_synchronize(long device_index);
352+
353+
/// <summary>
354+
/// Waits for all kernels in all streams on a CUDA device to complete.
355+
/// </summary>
356+
/// <param name="seed">Device for which to synchronize.
357+
/// It uses the current device, given by current_device(), if a device is not provided.</param>
358+
public static void synchronize(long seed = -1L)
359+
{
360+
TryInitializeDeviceType(DeviceType.CUDA);
361+
THSCuda_synchronize(seed);
362+
}
314363
}
315364

316365
[DllImport("LibTorchSharp")]

src/TorchSharp/TorchVision/Functional.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ private static Tensor HSVtoRGB(Tensor h, Tensor s, Tensor v)
764764
var q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0);
765765
var t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0);
766766

767-
var iunsq = i.unsqueeze(dimension: -3);
767+
var iunsq = i.unsqueeze(dim: -3);
768768
var mask = iunsq == torch.arange(6, device: i.device).view(-1, 1, 1);
769769

770770
var a1 = torch.stack(new Tensor[] { v, q, p, p, t, v }, dimension: -3);

test/TorchSharpTest/TestTorchTensor.cs

+7
Original file line numberDiff line numberDiff line change
@@ -4127,6 +4127,13 @@ public void SqueezeTest()
41274127
Assert.Equal(2.0f, res[1].ToSingle());
41284128
Assert.Equal(3.1f, res[2].ToSingle());
41294129
}
4130+
// And all dims.
4131+
using (var res = Float32Tensor.from(data).expand(new long[] { 1, 1, 3 }).squeeze()) {
4132+
Assert.Equal(new long[] { 3 }, res.shape);
4133+
Assert.Equal(1.1f, res[0].ToSingle());
4134+
Assert.Equal(2.0f, res[1].ToSingle());
4135+
Assert.Equal(3.1f, res[2].ToSingle());
4136+
}
41304137
}
41314138

41324139
[Fact]

0 commit comments

Comments
 (0)