Skip to content

Commit d7fc585

Browse files
authored
Merge pull request dotnet#26 from interesaaat/LibTorchSharpFirstTest
move mul and pow to scalar. Few other fixes.
2 parents bb13e5e + 68930b3 commit d7fc585

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

Examples/MNIST.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ private static void Train(
103103

104104
if (batchId % _logInterval == 0)
105105
{
106-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.DataItem<float>()}");
106+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.Item<float>()}");
107107
}
108108

109109
batchId++;
@@ -130,11 +130,11 @@ private static void Test(
130130
using (var prediction = model.Forward(data))
131131
using (var output = loss(prediction, target))
132132
{
133-
testLoss += output.DataItem<float>();
133+
testLoss += output.Item<float>();
134134

135135
var pred = output.Argmax(1);
136136

137-
correct += pred.Eq(target).Sum().DataItem<int>(); // Memory leak here
137+
correct += pred.Eq(target).Sum().Item<int>(); // Memory leak here
138138

139139
data.Dispose();
140140
target.Dispose();

Test/TorchSharp/TorchSharp.cs

+16-3
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,20 @@ public void ScoreModelCheckOutput()
292292
}
293293

294294
[TestMethod]
295-
public void CreateLinear()
295+
public void TestTensorToScalarMultiplication()
296+
{
297+
using (var tensor = FloatTensor.Ones(new long[] { 2, 2 }))
298+
{
299+
var neg = tensor * (-1).ToScalar();
300+
foreach (var val in neg.Data<float>())
301+
{
302+
Assert.AreEqual(val, -1.0);
303+
}
304+
}
305+
}
306+
307+
[TestMethod]
308+
public void TestCreateLinear()
296309
{
297310
var lin = NN.Module.Linear(1000, 100);
298311
Assert.IsNotNull(lin);
@@ -613,7 +626,7 @@ public void TestMul()
613626
{
614627
var x = FloatTensor.Ones(new long[] { 100, 100 });
615628

616-
var y = x.Mul(0.5f);
629+
var y = x.Mul(0.5f.ToScalar());
617630

618631
var ydata = y.Data<float>();
619632
var xdata = x.Data<float>();
@@ -683,7 +696,7 @@ public void TestTraining()
683696
var x = FloatTensor.RandomN(new long[] { 64, 1000 }, device: "cpu:0");
684697
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
685698

686-
float learning_rate = 0.00004f;
699+
Scalar learning_rate = 0.00004f.ToScalar();
687700
float prevLoss = float.MaxValue;
688701
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
689702

TorchSharp/Tensor/TorchTensor.cs

+11-6
Original file line numberDiff line numberDiff line change
@@ -540,11 +540,11 @@ public void MulInPlace(TorchTensor target)
540540
}
541541

542542
[DllImport("libTorchSharp")]
543-
extern static IntPtr THSTensor_mulS(IntPtr src, float scalar);
543+
extern static IntPtr THSTensor_mulS(IntPtr src, IntPtr scalar);
544544

545-
public TorchTensor Mul(float scalar)
545+
public TorchTensor Mul(Scalar scalar)
546546
{
547-
return new TorchTensor(THSTensor_mulS(handle, scalar));
547+
return new TorchTensor(THSTensor_mulS(handle, scalar.Handle));
548548
}
549549

550550
[DllImport("libTorchSharp")]
@@ -556,11 +556,11 @@ public TorchTensor Norm(int dimension, bool KeepDimension = false)
556556
}
557557

558558
[DllImport("libTorchSharp")]
559-
extern static IntPtr THSTensor_pow(IntPtr src, float scalar);
559+
extern static IntPtr THSTensor_pow(IntPtr src, IntPtr scalar);
560560

561-
public TorchTensor Pow(float scalar)
561+
public TorchTensor Pow(Scalar scalar)
562562
{
563-
return new TorchTensor(THSTensor_pow(handle, scalar));
563+
return new TorchTensor(THSTensor_pow(handle, scalar.Handle));
564564
}
565565

566566
[DllImport("libTorchSharp")]
@@ -621,6 +621,11 @@ public TorchTensor Sum(long[] dimensions, bool keepDimension = false)
621621
return left.Mul(right);
622622
}
623623

624+
public static TorchTensor operator *(TorchTensor left, Scalar right)
625+
{
626+
return left.Mul(right);
627+
}
628+
624629
public static TorchTensor operator -(TorchTensor left, TorchTensor right)
625630
{
626631
return left.Sub(right);

TorchSharp/Torch.cs

-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
namespace TorchSharp
66
{
7-
using Debug = System.Diagnostics.Debug;
8-
97
public static class Torch
108
{
119
[DllImport("libTorchSharp")]

0 commit comments

Comments
 (0)