Commit 5232b4f

New examples with save/load are working.
1 parent 2ae62a7 commit 5232b4f

File tree

7 files changed

+191
-99
lines changed

docfx/articles/memory.md

+13-4
@@ -8,6 +8,8 @@ Two approaches are available for memory management. Technique 1 is the default a
88

99
Note DiffSharp (which uses TorchSharp) relies on technique 1.
1010

11+
> Most of the examples included will use technique #1, doing frequent explicit calls to GC.Collect() in the training code -- if not after each batch in the training loop, at least after each epoch.
12+
1113
## Technique 1. Implicit disposal using finalizers
1214

1315
In this technique all tensors (CPU and GPU) are implicitly disposed via .NET finalizers.
@@ -21,19 +23,26 @@ This is not yet done when using general tensor operations. It is possible a mor
2123

2224
👎 The .NET GC doesn't know of the memory pressure from CPU tensors, so failure may happen if large tensors can't be allocated
2325

24-
👎 The .NET GC doesn't know of GPU resources
26+
👎 The .NET GC doesn't know of GPU resources.
27+
28+
👎 Native operations that allocate temporaries, whether on CPU or GPU, may fail -- the GC scheme implemented by TorchSharp only works when the allocation is initiated by .NET code.
2529

2630
## Technique 2. Explicit disposal
2731

2832
In this technique specific tensors (CPU and GPU) are explicitly disposed
2933
using `using` in C# or explicit calls to `System.IDisposable.Dispose()`.
3034

31-
👍 control
35+
👍 Specific lifetime management of all resources.
36+
37+
👎 Cumbersome, requiring lots of using statements in your code.
3238

33-
👎 you must know when to dispose
39+
👎 You must know when to dispose.
40+
41+
👎 Temporaries are not covered by this approach, so to maximize the benefit, you may have to store all temporaries in variables and dispose of them explicitly.
3442

3543
> NOTE: Disposing a tensor only releases the underlying storage if this is the last
36-
> live TorchTensor which has a view on that tensor.
44+
> live TorchTensor which has a view on that tensor -- the native runtime does reference counting of tensors.
45+
3746

3847
## Links and resources
3948

docfx/articles/saveload.md

+56
@@ -0,0 +1,56 @@
1+
# Saving and Restoring Models
2+
3+
When using PyTorch, the expected pattern for saving models to disk or other permanent storage media, and later restoring them, is to get the model's state and pickle it using the standard Python format.
4+
5+
```Python
6+
torch.save(model.state_dict(), 'model_weights.pth')
7+
```
8+
9+
When restoring the model, you are expected to first create a model of the exact same structure as the original, with random weights, then restore the state:
10+
11+
```Python
12+
model = [...]
13+
model.load_state_dict(torch.load('model_weights.pth'))
14+
```
15+
16+
This presents a couple of problems for a .NET implementation. First, Python pickling is intimately coupled with Python and its runtime object model. It is a complex format that supports object graphs that form DAGs and faithfully maintains all object state.
17+
18+
Second, in order to share models between .NET applications, Python pickling is not necessary, and even for moving model state from Python to .NET, it is overkill. The state of a model is a simple dictionary where the keys are strings and the values are tensors.
19+
20+
Therefore, TorchSharp in its current form implements its own very simple model serialization format, which allows models originating in either .NET or Python to be loaded using .NET, as long as the model was saved using that format.
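To make the idea of "a dictionary of strings to tensors" concrete, here is a minimal sketch of writing and reading such a dictionary as a flat binary stream. The layout used here (entry count, then name length, UTF-8 name, value count, and float32 values per entry) is a hypothetical one chosen for illustration; it is not the actual TorchSharp format, and plain lists of floats stand in for tensors.

```python
import io
import struct

def write_state(state, stream):
    # Hypothetical layout, NOT the actual TorchSharp format:
    # entry count, then per entry: name length, UTF-8 name,
    # value count, little-endian float32 values.
    stream.write(struct.pack("<I", len(state)))
    for name, values in state.items():
        encoded = name.encode("utf-8")
        stream.write(struct.pack("<I", len(encoded)))
        stream.write(encoded)
        stream.write(struct.pack("<I", len(values)))
        stream.write(struct.pack(f"<{len(values)}f", *values))

def read_state(stream):
    # Mirror of write_state: read the fields back in the same order.
    (count,) = struct.unpack("<I", stream.read(4))
    state = {}
    for _ in range(count):
        (name_len,) = struct.unpack("<I", stream.read(4))
        name = stream.read(name_len).decode("utf-8")
        (n,) = struct.unpack("<I", stream.read(4))
        state[name] = list(struct.unpack(f"<{n}f", stream.read(4 * n)))
    return state

# Round trip through an in-memory buffer; lists of floats stand in for tensors.
buffer = io.BytesIO()
original = {"fc1.weight": [0.5, -1.25, 2.0], "fc1.bias": [0.0]}
write_state(original, buffer)
buffer.seek(0)
restored = read_state(buffer)
```

Because nothing in such a layout depends on Python's object model, a reader can be written in any language, which is the property the paragraph above relies on.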
21+
22+
The MNIST and AdversarialExampleGeneration examples in this repo rely on saving and restoring model state -- the latter example relies on a pre-trained model from MNIST.
23+
24+
> A future version of TorchSharp may include support for reading and writing Python pickle files directly. There are
25+
26+
## How to use the TorchSharp format
27+
28+
29+
In C#, saving a model looks like this:
30+
31+
```C#
32+
model.save("model_weights.dat");
33+
```
34+
35+
It's important to note that calling 'save' will move the model to the CPU, where it remains after the call. If you need to continue to use the model after saving it, you will have to explicitly move it back:
36+
37+
```C#
38+
model.to(Device.CUDA);
39+
```
40+
41+
And loading it again is done by:
42+
43+
```C#
44+
model = [...];
45+
model.load("model_weights.dat");
46+
```
47+
48+
The model should be created on the CPU before loading weights, then moved to the target device.
49+
50+
If the model starts out in Python, there's a simple script that allows you to use code that is very similar to the PyTorch API to save models to the TorchSharp format. Rather than placing this trivial script in a Python package and publishing it, we choose to just refer you to the script file itself, [exportsd.py](../src/Python/exportsd.py), which has all the necessary code.
51+
52+
```Python
53+
f = open("model_weights.dat", "wb")
54+
exportsd.save_state_dict(model.to("cpu").state_dict(), f)
55+
f.close()
56+
```
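As a possible variation on the snippet above, the file can be opened in a `with` block so that it is closed even if the export raises. The sketch below assumes only the `save_state_dict(state_dict, file)` call shape shown above; `export_state_dict` and `fake_save` are hypothetical names introduced for illustration, and `fake_save` is a stand-in so the sketch runs without PyTorch or `exportsd`.

```python
import os
import tempfile

def export_state_dict(state_dict, path, save_fn):
    """Write state_dict to path using save_fn(state_dict, file_object)."""
    # The with block closes the file even if save_fn raises.
    with open(path, "wb") as f:
        save_fn(state_dict, f)

# Stand-in writer (hypothetical) so the sketch runs on its own; in real use,
# pass exportsd.save_state_dict and model.to("cpu").state_dict() instead.
def fake_save(state_dict, f):
    f.write(repr(sorted(state_dict)).encode("utf-8"))

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "model_weights.dat")
    export_state_dict({"fc1.bias": [0.0]}, target, fake_save)
    with open(target, "rb") as f:
        written = f.read()
```

With the real script, the call would look like `export_state_dict(model.to("cpu").state_dict(), "model_weights.dat", exportsd.save_state_dict)`.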

src/Examples/AdversarialExampleGeneration.cs

+52-79
@@ -18,7 +18,9 @@
1818
namespace TorchSharp.Examples
1919
{
2020
/// <summary>
21-
/// Simple MNIST Convolutional model.
21+
/// FGSM Attack
22+
///
23+
/// Based on : https://pytorch.org/tutorials/beginner/fgsm_tutorial.html
2224
/// </summary>
2325
/// <remarks>
2426
/// There are at least two interesting data sets to use with this example:
@@ -34,6 +36,13 @@ namespace TorchSharp.Examples
3436
///
3537
/// In each case, there are four .gz files to download. Place them in a folder and then point the '_dataLocation'
3638
/// constant below at the folder location.
39+
///
40+
/// The example is based on the PyTorch tutorial, but the results from attacking the model are very different from
41+
/// what the tutorial article notes, at least on the machine where it was developed. There is an order-of-magnitude lower
42+
/// drop-off in accuracy in this version. That said, when running the PyTorch tutorial on the same machine, the
43+
/// accuracy trajectories are the same between .NET and Python. If the base convolutional model is trained
44+
/// using Python, and then used for the FGSM attack in both .NET and Python, the drop-off trajectories are extremely
45+
/// close.
3746
/// </remarks>
3847
public class AdversarialExampleGeneration
3948
{
@@ -74,114 +83,78 @@ static void Main(string[] args)
7483
Utils.Decompress.DecompressGZipFile(Path.Combine(sourceDir, "t10k-labels-idx1-ubyte.gz"), targetDir);
7584
}
7685

86+
MNIST.Model model = null;
87+
7788
var normImage = TorchVision.Transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }, device: device);
7889

79-
using (var train = new MNISTReader(targetDir, "train", _trainBatchSize, device: device, shuffle: true, transform: normImage))
8090
using (var test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage)) {
8191

82-
var model = new Model("model", Device.CPU);
83-
8492
var modelFile = dataset + ".model.bin";
8593

8694
if (!File.Exists(modelFile)) {
8795
// We need the model to be trained first, because we want to start with a trained model.
8896
Console.WriteLine($"\n Running MNIST on {device.Type.ToString()} in order to pre-train the model.");
89-
MNIST.TrainingLoop(dataset, device, train, test);
90-
Console.WriteLine("Moving on to the Adversarial model.\n");
91-
}
9297

93-
model.load(modelFile);
94-
model.to(device);
98+
model = new MNIST.Model("model", device);
9599

96-
// Establish a baseline accuracy.
100+
using (MNISTReader train = new MNISTReader(targetDir, "train", _trainBatchSize, device: device, shuffle: true, transform: normImage)) {
101+
MNIST.TrainingLoop(dataset, device, model, train, test);
102+
}
97103

98-
Stopwatch sw = new Stopwatch();
99-
sw.Start();
104+
Console.WriteLine("Moving on to the Adversarial model.\n");
100105

101-
var baseline = TestBaseline(model, nll_loss(reduction: NN.Reduction.Sum), test, test.Size);
106+
} else {
107+
model = new MNIST.Model("model", Device.CPU);
108+
model.load(modelFile);
109+
}
102110

103-
Console.WriteLine($"\rBaseline model accuracy: {baseline}");
111+
model.to(device);
112+
model.Eval();
104113

105-
sw.Stop();
106-
Console.WriteLine($"Elapsed time: {sw.Elapsed.TotalSeconds} s.");
114+
var epsilons = new double[] { 0, 0.05, 0.1, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50 };
107115

108-
GC.Collect();
116+
foreach (var ε in epsilons) {
117+
var attacked = Test(model, nll_loss(), ε, test, test.Size);
118+
Console.WriteLine($"Epsilon: {ε:F2}, accuracy: {attacked:P2}");
119+
}
109120
}
110121
}
111122

112-
private class Model : CustomModule
123+
private static TorchTensor Attack(TorchTensor image, double ε, TorchTensor data_grad)
113124
{
114-
private Conv2d conv1 = Conv2d(1, 32, 3);
115-
private Conv2d conv2 = Conv2d(32, 64, 3);
116-
private Linear fc1 = Linear(9216, 128);
117-
private Linear fc2 = Linear(128, 10);
118-
119-
// These don't have any parameters, so the only reason to instantiate
120-
// them is performance, since they will be used over and over.
121-
private MaxPool2d pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 });
122-
123-
private ReLU relu1 = ReLU();
124-
private ReLU relu2 = ReLU();
125-
private ReLU relu3 = ReLU();
126-
127-
private FeatureAlphaDropout dropout1 = FeatureAlphaDropout();
128-
private Dropout dropout2 = Dropout();
129-
130-
private Flatten flatten = Flatten();
131-
private LogSoftmax logsm = LogSoftmax(1);
132-
133-
134-
public Model(string name, Device device = null) : base(name)
135-
{
136-
RegisterComponents();
137-
138-
if (device != null && device.Type == DeviceType.CUDA)
139-
this.to(device);
140-
}
141-
142-
public override TorchTensor forward(TorchTensor input)
143-
{
144-
var l11 = conv1.forward(input);
145-
var l12 = relu2.forward(l11);
146-
147-
var l21 = conv2.forward(l12);
148-
var l22 = pool1.forward(l21);
149-
var l23 = dropout1.forward(l22);
150-
var l24 = relu2.forward(l23);
151-
152-
var x = flatten.forward(l24);
153-
154-
var l31 = fc1.forward(x);
155-
var l32 = relu3.forward(l31);
156-
var l33 = dropout2.forward(l32);
157-
158-
var l41 = fc2.forward(l33);
159-
160-
return logsm.forward(l41);
125+
using (var sign = data_grad.sign()) {
126+
var perturbed = (image + ε * sign).clamp(0.0, 1.0);
127+
return perturbed;
161128
}
162129
}
163130

164-
private static double TestBaseline(
165-
Model model,
166-
Loss loss,
131+
private static double Test(
132+
MNIST.Model model,
133+
Loss criterion,
134+
double ε,
167135
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
168136
long size)
169137
{
170-
model.Eval();
171-
172-
double testLoss = 0;
173138
int correct = 0;
174139

175-
foreach (var (data, target) in dataLoader)
176-
{
177-
var prediction = model.forward(data);
178-
var output = loss(prediction, target);
179-
testLoss += output.ToSingle();
140+
foreach (var (data, target) in dataLoader) {
141+
142+
data.requires_grad = true;
143+
144+
using (var output = model.forward(data))
145+
using (var loss = criterion(output, target)) {
180146

181-
var pred = prediction.argmax(1);
182-
correct += pred.eq(target).sum().ToInt32();
147+
model.ZeroGrad();
148+
loss.backward();
149+
150+
var perturbed = Attack(data, ε, data.grad());
151+
152+
using (var final = model.forward(perturbed)) {
153+
154+
correct += final.argmax(1).eq(target).sum().ToInt32();
155+
}
156+
}
183157

184-
pred.Dispose();
185158

186159
GC.Collect();
187160
}
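For readers less familiar with FGSM, the perturbation computed by the `Attack` method in this file can be sketched without any tensor library. This is an illustrative, element-wise version that uses plain Python lists in place of tensors; the function names are hypothetical and not part of TorchSharp or PyTorch.

```python
def sign(x):
    # Returns -1, 0, or 1, like an element-wise tensor sign().
    return (x > 0) - (x < 0)

def fgsm_perturb(image, epsilon, grad):
    # perturbed = clamp(image + epsilon * sign(grad), 0, 1), element by element.
    return [min(1.0, max(0.0, p + epsilon * sign(g)))
            for p, g in zip(image, grad)]

# Three "pixels": one pushed below 0, one with zero gradient, one pushed above 1.
perturbed = fgsm_perturb([0.0, 0.5, 0.95], 0.1, [-2.0, 0.0, 3.0])
```

Each value moves by at most `epsilon` in the direction that increases the loss, and the clamp keeps the result in the range [0, 1].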

src/Examples/Examples.csproj

+1-1
@@ -8,7 +8,7 @@
88
<TestUsesLibTorch>true</TestUsesLibTorch>
99
<UseMLCodeAnalyzer>false</UseMLCodeAnalyzer>
1010
<UseStyleCopAnalyzer>false</UseStyleCopAnalyzer>
11-
<StartupObject>TorchSharp.Examples.MNIST</StartupObject>
11+
<StartupObject>TorchSharp.Examples.AdversarialExampleGeneration</StartupObject>
1212
<IsPackable>false</IsPackable>
1313
<PlatformTarget>x64</PlatformTarget>
1414
<RootNamespace>TorchSharp.Examples</RootNamespace>

src/Examples/MNIST.cs

+17-14
@@ -50,7 +50,6 @@ static void Main(string[] args)
5050

5151
var cwd = Environment.CurrentDirectory;
5252

53-
//var device = Device.CPU;
5453
var device = Torch.IsCudaAvailable() ? Device.CUDA : Device.CPU;
5554
Console.WriteLine($"Running MNIST on {device.Type.ToString()}");
5655
Console.WriteLine($"Dataset: {dataset}");
@@ -69,21 +68,25 @@ static void Main(string[] args)
6968
if (device.Type == DeviceType.CUDA) {
7069
_trainBatchSize *= 4;
7170
_testBatchSize *= 4;
72-
_epochs *= 4;
7371
}
7472

73+
var model = new Model("model", device);
74+
7575
var normImage = TorchVision.Transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }, device: device);
7676

7777
using (MNISTReader train = new MNISTReader(targetDir, "train", _trainBatchSize, device: device, shuffle: true, transform: normImage),
7878
test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage)) {
7979

80-
TrainingLoop(dataset, device, train, test);
80+
TrainingLoop(dataset, device, model, train, test);
8181
}
8282
}
8383

84-
internal static void TrainingLoop(string dataset, Device device, MNISTReader train, MNISTReader test)
84+
internal static void TrainingLoop(string dataset, Device device, Model model, MNISTReader train, MNISTReader test)
8585
{
86-
var model = new Model("model", device);
86+
if (device.Type == DeviceType.CUDA) {
87+
_epochs *= 4;
88+
}
89+
8790
var optimizer = NN.Optimizer.Adam(model.parameters());
8891

8992
var scheduler = NN.Optimizer.StepLR(optimizer, 1, 0.7, last_epoch: 5);
@@ -100,13 +103,13 @@ internal static void TrainingLoop(string dataset, Device device, MNISTReader tra
100103
}
101104

102105
sw.Stop();
103-
Console.WriteLine($"Elapsed time: {sw.Elapsed.TotalSeconds} s.");
106+
Console.WriteLine($"Elapsed time: {sw.Elapsed.TotalSeconds:F1} s.");
104107

105108
Console.WriteLine("Saving model to '{0}'", dataset + ".model.bin");
106109
model.save(dataset + ".model.bin");
107110
}
108111

109-
private class Model : CustomModule
112+
internal class Model : CustomModule
110113
{
111114
private Conv2d conv1 = Conv2d(1, 32, 3);
112115
private Conv2d conv2 = Conv2d(32, 64, 3);
@@ -121,8 +124,8 @@ private class Model : CustomModule
121124
private ReLU relu2 = ReLU();
122125
private ReLU relu3 = ReLU();
123126

124-
private FeatureAlphaDropout dropout1 = FeatureAlphaDropout();
125-
private Dropout dropout2 = Dropout();
127+
private Dropout dropout1 = Dropout(0.25);
128+
private Dropout dropout2 = Dropout(0.5);
126129

127130
private Flatten flatten = Flatten();
128131
private LogSoftmax logsm = LogSoftmax(1);
@@ -141,9 +144,9 @@ public override TorchTensor forward(TorchTensor input)
141144
var l12 = relu2.forward(l11);
142145

143146
var l21 = conv2.forward(l12);
144-
var l22 = pool1.forward(l21);
145-
var l23 = dropout1.forward(l22);
146-
var l24 = relu2.forward(l23);
147+
var l22 = relu2.forward(l21);
148+
var l23 = pool1.forward(l22);
149+
var l24 = dropout1.forward(l23);
147150

148151
var x = flatten.forward(l24);
149152

@@ -184,7 +187,7 @@ private static void Train(
184187
optimizer.step();
185188

186189
if (batchId % _logInterval == 0) {
187-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle()}");
190+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle():F4}");
188191
}
189192

190193
batchId++;
@@ -220,7 +223,7 @@ private static void Test(
220223

221224
Console.WriteLine($"Size: {size}, Total: {size}");
222225

223-
Console.WriteLine($"\rTest set: Average loss {testLoss / size} | Accuracy {(double)correct / size}");
226+
Console.WriteLine($"\rTest set: Average loss {(testLoss / size):F4} | Accuracy {((double)correct / size):P2}");
224227
}
225228
}
226229
}

src/Examples/MNISTReader.cs

+1-1
@@ -92,7 +92,7 @@ public MNISTReader(string path, string prefix, int batch_size = 32, bool shuffle
9292
var idx = indices[i++];
9393
var imgStart = idx * imgSize;
9494

95-
var floats = dataBytes[imgStart.. (imgStart+imgSize)].Select(b => (float)b).ToArray();
95+
var floats = dataBytes[imgStart.. (imgStart+imgSize)].Select(b => b/256.0f).ToArray();
9696
using (var inputTensor = Float32Tensor.from(floats))
9797
dataTensor.index_put_(new TorchTensorIndex [] { TorchTensorIndex.Single(j) }, inputTensor);
9898
lablTensor[j] = Int64Tensor.from(labelBytes[idx]);
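The change from `(float)b` to `b/256.0f` matters because the examples pass the constants 0.1307 and 0.3081 to `Normalize`, and those values only make sense for pixel intensities that have already been scaled down to roughly the unit interval. The sketch below assumes that the transform computes `(x - mean) / std` per element, as the torchvision transform of the same name does; `scale_and_normalize` is a hypothetical helper, not part of the repository.

```python
MEAN, STD = 0.1307, 0.3081  # constants passed to Normalize in the examples

def scale_and_normalize(pixels):
    # Scale raw bytes (0..255) into [0, 1), as the new MNISTReader code does,
    # then apply the assumed (x - mean) / std normalization.
    scaled = [b / 256.0 for b in pixels]
    return [(v - MEAN) / STD for v in scaled]

values = scale_and_normalize(bytes([0, 128, 255]))
```

Without the scaling step, a raw byte such as 255 would be normalized to a value in the hundreds rather than a small number near the intended range.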
