Commit 03130f5

datasets bug fix

1 parent e7f82f5 commit 03130f5
2 files changed: +11 -9 lines changed

dataset.py (+10 -9)
@@ -19,11 +19,12 @@
 
 import spacy
 
-def split_train_valid(frac=0.7, path_data, path_train, path_valid):
+def split_train_valid(path_data, path_train, path_valid, frac=0.7):
     df = pd.read_csv(path_data)
     rng = RandomState()
     tr = df.sample(frac=0.7, random_state=rng)
     tst = df.loc[~df.index.isin(tr.index)]
+    print("spliting original file to train/valid set...")
     tr.to_csv(path_train, index=False)
     tst.to_csv(path_valid, index=False)
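
Reviewer note: the reorder is what makes this file importable at all, since Python raises SyntaxError: non-default argument follows default argument when a defaulted parameter like frac=0.7 precedes required ones. One slip survives the fix: the body still hardcodes frac=0.7 in df.sample instead of using the parameter. A minimal sketch of the intended function, assuming RandomState is numpy.random.RandomState as the surrounding code suggests:

    import pandas as pd
    from numpy.random import RandomState

    def split_train_valid(path_data, path_train, path_valid, frac=0.7):
        df = pd.read_csv(path_data)
        rng = RandomState()
        # Sample `frac` of the rows for training; the committed code
        # still hardcodes 0.7 here.
        tr = df.sample(frac=frac, random_state=rng)
        # Everything not sampled becomes the validation set.
        tst = df.loc[~df.index.isin(tr.index)]
        print("splitting original file into train/valid sets...")
        tr.to_csv(path_train, index=False)
        tst.to_csv(path_valid, index=False)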

@@ -50,8 +51,8 @@ def clean_str(string):
     string = re.sub(r"\s{2,}", " ", string)
     return string.strip()
 
-def create_tabular_dataset(lang='en', pretrained_emb='glove.6B.300d',
-                           path_train, path_valid):
+def create_tabular_dataset(path_train, path_valid,
+                           lang='en', pretrained_emb='glove.6B.300d'):
 
     spacy_en = spacy.load('en', disable=['tagger', 'parser', 'ner', 'textcat'
                                          'entity_ruler', 'sentencizer',
@@ -72,7 +73,7 @@ def tokenizer(text):
     print('Creating tabular datasets...')
     train_datafield = [('text', TEXT), ('label', LABEL)]
     tabular_train = TabularDataset(path = path_train,
-                                   format= path_valid,
+                                   format= 'csv',
                                    skip_header=True,
                                    fields=train_datafield)

@@ -81,11 +82,11 @@ def tokenizer(text):
     tabular_valid = TabularDataset(path = path_valid,
                                    format='csv',
                                    skip_header=True,
-                                   fields=test_datafield)
+                                   fields=valid_datafield)
 
     print('Building vocaulary...')
-    TEXT.build_vocab(train, vectors= pretrained_emb)
-    LABEL.build_vocab(train)
+    TEXT.build_vocab(tabular_train, vectors= pretrained_emb)
+    LABEL.build_vocab(tabular_train)
 
 
     return tabular_train, tabular_valid, TEXT.vocab
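
Reviewer note: three fixes land in create_tabular_dataset. The required path arguments move ahead of the defaulted ones (the same SyntaxError as above); format= path_valid, an apparent copy-paste slip that passed a file path as the format string, becomes the intended 'csv'; and the undefined names test_datafield and train are replaced with variables that actually exist in scope. (Unrelated but visible in context: the spacy.load disable list is missing a comma after 'textcat', so Python silently concatenates it with 'entity_ruler'.) A minimal sketch of the corrected flow, assuming the legacy torchtext (<=0.9) data API this file targets and placeholder Field definitions:

    from torchtext.data import Field, LabelField, TabularDataset

    # Placeholder Field setup; the real TEXT/LABEL definitions live
    # elsewhere in dataset.py and use the spacy tokenizer defined above.
    TEXT = Field(sequential=True, lower=True)
    LABEL = LabelField()

    datafields = [('text', TEXT), ('label', LABEL)]
    tabular_train = TabularDataset(path=path_train, format='csv',
                                   skip_header=True, fields=datafields)
    tabular_valid = TabularDataset(path=path_valid, format='csv',
                                   skip_header=True, fields=datafields)

    # Build the vocab from the dataset object that exists here; the old
    # code referenced an undefined name `train`.
    TEXT.build_vocab(tabular_train, vectors='glove.6B.300d')
    LABEL.build_vocab(tabular_train)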
@@ -95,14 +96,14 @@ def create_data_iterator(tr_batch_size, val_batch_size,tabular_train,
     #Create the Iterator for datasets (Iterator works like dataloader)
 
     train_iter = Iterator(
-        train,
+        tabular_train,
        batch_size=tr_batch_size,
        device= d,
        sort_within_batch=False,
        repeat=False)
 
     valid_iter = Iterator(
-        test,
+        tabular_valid,
        batch_size=val_batch_size,
        device=d,
        sort_within_batch=False,
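
Reviewer note: with these renames the iterators reference the function's own arguments instead of the undefined train and test. A hedged end-to-end usage sketch; the paths, batch sizes, and device are illustrative, and since the visible hunk truncates the signature, the trailing tabular_valid and device parameters, and the assumption that both iterators are returned, are guesses, not from the commit:

    import torch

    # Hypothetical file paths, for illustration only.
    split_train_valid('data/reviews.csv', 'data/train.csv',
                      'data/valid.csv', frac=0.7)
    tabular_train, tabular_valid, vocab = create_tabular_dataset(
        'data/train.csv', 'data/valid.csv')
    train_iter, valid_iter = create_data_iterator(
        64, 64, tabular_train, tabular_valid, torch.device('cpu'))

    for batch in train_iter:
        # Attribute names come from the ('text', TEXT) / ('label', LABEL) fields.
        text, label = batch.text, batch.label
        break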

model.py (+1)
@@ -15,6 +15,7 @@ class textCNN(nn.Module):
     def __init__(self, vocab_built, emb_dim, dim_channel, kernel_wins, num_class):
         super(textCNN, self).__init__()
         #load pretrained embedding in embedding layer.
+        print(vocab_built)
         self.embed = nn.Embedding(len(vocab_built), emb_dim)
         self.embed.weight.data.copy_(vocab_built.vectors)
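
Reviewer note: the model.py change only adds a debug print of the vocab object before the embedding layer is built; len(vocab_built) sizes the embedding table and vocab_built.vectors supplies the pretrained GloVe weights, so emb_dim must match the vector dimensionality. A minimal construction sketch with illustrative hyperparameters:

    # vocab is the TEXT.vocab returned by create_tabular_dataset above.
    # 300 matches glove.6B.300d; the channel count, kernel widths, and
    # class count are illustrative values, not from this commit.
    model = textCNN(vocab, emb_dim=300, dim_channel=100,
                    kernel_wins=[3, 4, 5], num_class=2)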
