
Commit fcae7b5

finish modularization

1 parent 03130f5

6 files changed: +50,162 −465 lines

IMDB_Dataset.csv (+50,001)

Large diffs are not rendered by default.

dataset.py (+4 −4)
@@ -24,7 +24,7 @@ def split_train_valid(path_data, path_train, path_valid, frac=0.7):
     rng = RandomState()
     tr = df.sample(frac=0.7, random_state=rng)
     tst = df.loc[~df.index.isin(tr.index)]
-    print("spliting original file to train/valid set...")
+    print("Splitting original file to train/valid set...")
     tr.to_csv(path_train, index=False)
     tst.to_csv(path_valid, index=False)

@@ -70,7 +70,7 @@ def tokenizer(text):
     #clean the text
     TEXT.preprocessing = torchtext.data.Pipeline(clean_str)

-    print('Creating tabular datasets...')
+    print('Creating tabular datasets... It might take a while to finish!')
     train_datafield = [('text', TEXT), ('label', LABEL)]
     tabular_train = TabularDataset(path=path_train,
                                    format='csv',

@@ -98,13 +98,13 @@ def create_data_iterator(tr_batch_size, val_batch_size,tabular_train,
     train_iter = Iterator(
         tabular_train,
         batch_size=tr_batch_size,
-        device= d,
+        device=d,
         sort_within_batch=False,
         repeat=False)

     valid_iter = Iterator(
         tabular_valid,
-        batch_size=val_batch_size,
+        batch_size=val_batch_size,
         device=d,
         sort_within_batch=False,
         repeat=False)
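
Taken together, these helpers form the data pipeline that main.py (below) drives. A minimal sketch of how they chain, with call signatures inferred from this diff and from main.py; note that split_train_valid samples with a hardcoded frac=0.7 internally, so its frac argument is effectively ignored in this revision:

    import torch
    import dataset

    # 70/30 split of the raw CSV into train/valid files (paths as used in main.py)
    dataset.split_train_valid('./IMDB_Dataset.csv', './train.csv', './valid.csv', 0.7)

    # build torchtext TabularDatasets plus a vocabulary carrying pretrained vectors
    trainset, validset, vocab = dataset.create_tabular_dataset(
        './train.csv', './valid.csv', 'en', 'glove.6B.300d')

    # wrap both sets in batching Iterators on the available device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_iter, valid_iter = dataset.create_data_iterator(
        64, 64, trainset, validset, device)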

main.py (+149)
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+"""main.ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/17VqaftLE6Xo9iUJryPp7DbtrZNmNHU1J
+"""
+
+import argparse
+
+import torch
+import torch.optim as optim
+
+import dataset
+import model
+import training
+
+import matplotlib.pyplot as plt
+
+
+#%%
+
+def main():
+
+    print("Pytorch Version:", torch.__version__)
+    parser = argparse.ArgumentParser(description='TextCNN')
+    # Training args
+    parser.add_argument('--data-csv', type=str, default='./IMDB_Dataset.csv', metavar='./IMDB_Dataset.csv',
+                        help='file path of training data in CSV format (default: ./IMDB_Dataset.csv)')
+
+    parser.add_argument('--spacy-lang', type=str, default='en', metavar='en',
+                        help='language choice for spacy to tokenize the text (ex: en or fr)')
+
+    parser.add_argument('--pretrained', type=str, default='glove.6B.300d', metavar='glove.6B.300d',
+                        help='choice of pretrained word embedding from torchtext')
+
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+
+    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                        help='learning rate (default: 0.001)')
+
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+
+    parser.add_argument('--val-batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for testing (default: 64)')
+
+    parser.add_argument('--kernel-height', type=str, default='3,4,5', metavar='S',
+                        help='kernel heights for convolution, comma-separated (default: 3,4,5)')
+
+    parser.add_argument('--out-channel', type=int, default=100, metavar='N',
+                        help='output channels for convolutional layer (default: 100)')
+
+    parser.add_argument('--dropout', type=float, default=0.5, metavar='N',
+                        help='dropout rate for linear layer (default: 0.5)')
+
+    parser.add_argument('--num-class', type=int, default=2, metavar='N',
+                        help='number of categories to classify (default: 2)')
+
+    # use parse_known_args if you are running inside a Jupyter notebook
+    args = parser.parse_known_args()[0]
+    #args = parser.parse_args()
+
+
+    # Use GPU if it is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+    #%% Split whole dataset into train and valid set
+    dataset.split_train_valid(args.data_csv, './train.csv', './valid.csv', 0.7)
+
+    trainset, validset, vocab = dataset.create_tabular_dataset('./train.csv',
+                                    './valid.csv', args.spacy_lang, args.pretrained)
+
+    #%% Show some examples from train/valid
+    print("Show some examples from train/valid..")
+    print(trainset[0].text, trainset[0].label)
+    print(validset[0].text, validset[0].label)
+
+    train_iter, valid_iter = dataset.create_data_iterator(args.batch_size, args.val_batch_size,
+                                                          trainset, validset, device)
+
+    #%% Create the model
+    kernels = [int(x) for x in args.kernel_height.split(',')]
+    m = model.textCNN(vocab, args.out_channel, kernels, args.num_class).to(device)
+    # print the model summary
+    print(m)
+
+    train_loss = []
+    train_acc = []
+    test_loss = []
+    test_acc = []
+    best_test_acc = -1
+
+    # optimizer
+    optimizer = optim.Adam(m.parameters(), lr=args.lr)
+
+    for epoch in range(1, args.epochs+1):
+        # train loss
+        tr_loss, tr_acc = training.train(m, device, train_iter, optimizer, epoch, args.epochs)
+        print('Train Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, tr_loss, tr_acc))
+
+        ts_loss, ts_acc = training.valid(m, device, valid_iter)
+        print('Valid Epoch: {} \t Loss: {} \t Accuracy: {}%'.format(epoch, ts_loss, ts_acc))
+
+        if ts_acc > best_test_acc:
+            best_test_acc = ts_acc
+            # save parameters (snapshot) of the best model so far
+            print("model saved at {}% accuracy".format(best_test_acc))
+            torch.save(m.state_dict(), "best_validation")
+
+        train_loss.append(tr_loss)
+        train_acc.append(tr_acc)
+        test_loss.append(ts_loss)
+        test_acc.append(ts_acc)
+
+    # plot train/validation loss versus epoch
+    x = list(range(1, args.epochs+1))
+    plt.figure()
+    plt.title("train/validation loss versus epoch")
+    plt.xlabel("epoch")
+    plt.ylabel("total loss")
+    plt.plot(x, train_loss, label="train loss")
+    plt.plot(x, test_loss, color='red', label="test loss")
+    plt.legend(loc='upper right')
+    plt.grid(True)
+    plt.show()
+
+    # plot train/validation accuracy versus epoch
+    x = list(range(1, args.epochs+1))
+    plt.figure()
+    plt.title("train/validation accuracy versus epoch")
+    plt.xlabel("epoch")
+    plt.ylabel("accuracy")
+    plt.plot(x, train_acc, label="train accuracy")
+    plt.plot(x, test_acc, color='red', label="test accuracy")
+    plt.legend(loc='upper right')
+    plt.grid(True)
+    plt.show()
+
+if __name__ == '__main__':
+    main()
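
Since training writes the best-performing weights with torch.save(m.state_dict(), "best_validation"), a later session can restore them. A minimal sketch, assuming the same vocabulary object and the default hyperparameters (out-channel 100, kernels 3,4,5, 2 classes):

    import torch
    import model

    # rebuild the architecture with the same settings used for training
    m = model.textCNN(vocab, 100, [3, 4, 5], 2)

    # load the snapshot written during training and switch to inference mode
    m.load_state_dict(torch.load("best_validation"))
    m.eval()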

model.py (+2 −2)
@@ -12,10 +12,10 @@
 #%% Text CNN model
 class textCNN(nn.Module):
 
-    def __init__(self, vocab_built, emb_dim, dim_channel, kernel_wins, num_class):
+    def __init__(self, vocab_built, dim_channel, kernel_wins, num_class):
         super(textCNN, self).__init__()
         #load pretrained embedding in embedding layer.
-        print(vocab_built)
+        emb_dim = vocab_built.vectors.size()[1]
         self.embed = nn.Embedding(len(vocab_built), emb_dim)
         self.embed.weight.data.copy_(vocab_built.vectors)
 
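
The signature change works because a torchtext vocabulary built with pretrained vectors exposes them as a (vocab_size, emb_dim) tensor in .vectors, so the embedding width no longer needs to be passed in by the caller. A minimal illustration, assuming glove.6B.300d as in main.py:

    import torch.nn as nn

    # vocab_built.vectors has shape (len(vocab_built), emb_dim)
    emb_dim = vocab_built.vectors.size()[1]           # 300 for glove.6B.300d
    embed = nn.Embedding(len(vocab_built), emb_dim)   # matches the vectors' shape
    embed.weight.data.copy_(vocab_built.vectors)      # initialize from GloVe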
