# -*- coding: utf-8 -*-
"""main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/17VqaftLE6Xo9iUJryPp7DbtrZNmNHU1J
"""
import argparse

import torch
import torch.optim as optim

import dataset
import model
import training

import matplotlib.pyplot as plt


#%%
def main():
    """Train a TextCNN sentiment classifier on a CSV dataset.

    Parses command-line arguments, splits the input CSV into train/valid
    files, builds the vocabulary and data iterators, trains for the
    requested number of epochs while checkpointing the parameters with the
    best validation accuracy, then plots the loss and accuracy curves.
    """
    print("Pytorch Version:", torch.__version__)

    parser = argparse.ArgumentParser(description='TextCNN')
    # Training args
    parser.add_argument('--data-csv', type=str, default='./IMDB_Dataset.csv', metavar='./IMDB_Dataset.csv',
                        help='file path of training data in CSV format (default: ./IMDB_Dataset.csv)')

    parser.add_argument('--spacy-lang', type=str, default='en', metavar='en',
                        help='language choice for spacy to tokenize the text (ex: en or fr)')

    parser.add_argument('--pretrained', type=str, default='glove.6B.300d', metavar='glove.6B.300d',
                        help='choice of pretrained word embedding from torchtext')

    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    # NOTE(review): momentum is accepted but unused below — the optimizer is
    # Adam, not SGD. Kept for CLI backward compatibility.
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9; unused with the Adam optimizer)')

    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--val-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for validation (default: 64)')

    parser.add_argument('--kernel-height', type=str, default='3,4,5', metavar='S',
                        help='comma-separated kernel heights for convolution (default: 3,4,5)')

    parser.add_argument('--out-channel', type=int, default=100, metavar='N',
                        help='output channels for each convolutional layer (default: 100)')

    # BUG FIX: dropout is a probability, so it must be parsed as float.
    # With type=int, any value supplied on the command line was truncated
    # (e.g. "0.5" would even raise, and "1" would disable the layer entirely).
    parser.add_argument('--dropout', type=float, default=0.5, metavar='P',
                        help='dropout rate for linear layer (default: 0.5)')

    parser.add_argument('--num-class', type=int, default=2, metavar='N',
                        help='number of categories to classify (default: 2)')

    # parse_known_args tolerates the extra argv that Jupyter/Colab injects;
    # switch to parser.parse_args() when running as a plain script.
    args = parser.parse_known_args()[0]

    # Use GPU if it is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Split the whole dataset into train (70%) and valid (30%) CSVs.
    dataset.split_train_valid(args.data_csv, './train.csv', './valid.csv', 0.7)

    trainset, validset, vocab = dataset.create_tabular_dataset(
        './train.csv', './valid.csv', args.spacy_lang, args.pretrained)

    # Show a few examples to sanity-check the dataset.
    print("Show some examples from train/valid..")
    print(trainset[0].text, trainset[0].label)
    print(validset[0].text, validset[0].label)

    train_iter, valid_iter = dataset.create_data_iterator(
        args.batch_size, args.val_batch_size, trainset, validset, device)

    # Build the model and print its summary.
    kernels = [int(x) for x in args.kernel_height.split(',')]
    m = model.textCNN(vocab, args.out_channel, kernels, args.num_class).to(device)
    print(m)

    train_loss, train_acc = [], []
    test_loss, test_acc = [], []
    best_test_acc = -1

    optimizer = optim.Adam(m.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        tr_loss, tr_acc = training.train(m, device, train_iter, optimizer, epoch, args.epochs)
        print('Train Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, tr_loss, tr_acc))

        ts_loss, ts_acc = training.valid(m, device, valid_iter)
        print('Valid Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, ts_loss, ts_acc))

        # Snapshot the parameters whenever validation accuracy improves.
        if ts_acc > best_test_acc:
            best_test_acc = ts_acc
            print("model saves at {}% accuracy".format(best_test_acc))
            torch.save(m.state_dict(), "best_validation")

        train_loss.append(tr_loss)
        train_acc.append(tr_acc)
        test_loss.append(ts_loss)
        test_acc.append(ts_acc)

    epochs_axis = list(range(1, args.epochs + 1))

    # Plot train/validation loss versus epoch.
    _plot_curves(epochs_axis, train_loss, test_loss,
                 "train/validation loss versus epoch", "total loss",
                 "train loss", "test loss")

    # BUG FIX: the accuracy plot previously reused the loss plot's title
    # and ylabel verbatim; label it as accuracy.
    _plot_curves(epochs_axis, train_acc, test_acc,
                 "train/validation accuracy versus epoch", "accuracy (%)",
                 "train accuracy", "test accuracy")


def _plot_curves(x, train_vals, valid_vals, title, ylabel, train_label, valid_label):
    """Plot one train curve and one validation curve against epoch number."""
    plt.figure()
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.plot(x, train_vals, label=train_label)
    plt.plot(x, valid_vals, color='red', label=valid_label)
    plt.legend(loc='upper right')
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    main()