@@ -287,6 +287,10 @@ internal static bool CallTorchCudaIsAvailable()
287
287
return THSTorchCuda_is_available ( ) ;
288
288
}
289
289
290
+ /// <summary>
291
+ /// Returns a bool indicating if CUDA is currently available.
292
+ /// </summary>
293
+ /// <returns></returns>
290
294
public static bool is_available ( )
291
295
{
292
296
TryInitializeDeviceType ( DeviceType . CUDA ) ;
@@ -305,12 +309,57 @@ public static bool is_cudnn_available()
305
309
[ DllImport ( "LibTorchSharp" ) ]
306
310
private static extern int THSTorchCuda_device_count ( ) ;
307
311
312
+ /// <summary>
313
+ /// Returns the number of GPUs available.
314
+ /// </summary>
315
+ /// <returns></returns>
308
316
public static int device_count ( )
309
317
{
310
318
TryInitializeDeviceType ( DeviceType . CUDA ) ;
311
319
return THSTorchCuda_device_count ( ) ;
312
320
}
313
321
322
+ [ DllImport ( "LibTorchSharp" ) ]
323
+ private static extern void THSCuda_manual_seed ( long seed ) ;
324
+
325
+ /// <summary>
326
+ /// Sets the seed for generating random numbers for the current GPU.
327
+ /// It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.
328
+ /// </summary>
329
+ /// <param name="seed">The desired seed.</param>
330
+ public static void manual_seed ( long seed )
331
+ {
332
+ TryInitializeDeviceType ( DeviceType . CUDA ) ;
333
+ THSCuda_manual_seed ( seed ) ;
334
+ }
335
+
336
+ [ DllImport ( "LibTorchSharp" ) ]
337
+ private static extern void THSCuda_manual_seed_all ( long seed ) ;
338
+
339
+ /// <summary>
340
+ /// Sets the seed for generating random numbers on all GPUs.
341
+ /// It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.
342
+ /// </summary>
343
+ /// <param name="seed"></param>
344
+ public static void manual_seed_all ( long seed )
345
+ {
346
+ TryInitializeDeviceType ( DeviceType . CUDA ) ;
347
+ THSCuda_manual_seed_all ( seed ) ;
348
+ }
349
+
350
+ [ DllImport ( "LibTorchSharp" ) ]
351
+ private static extern void THSCuda_synchronize ( long device_index ) ;
352
+
353
+ /// <summary>
354
+ /// Waits for all kernels in all streams on a CUDA device to complete.
355
+ /// </summary>
356
+ /// <param name="seed">Device for which to synchronize.
357
+ /// It uses the current device, given by current_device(), if a device is not provided.</param>
358
+ public static void synchronize ( long seed = - 1L )
359
+ {
360
+ TryInitializeDeviceType ( DeviceType . CUDA ) ;
361
+ THSCuda_synchronize ( seed ) ;
362
+ }
314
363
}
315
364
316
365
[ DllImport ( "LibTorchSharp" ) ]
0 commit comments