19
19
20
20
import spacy
21
21
22
- def split_train_valid (frac = 0.7 , path_data , path_train , path_valid ):
22
+ def split_train_valid (path_data , path_train , path_valid , frac = 0.7 ):
23
23
df = pd .read_csv (path_data )
24
24
rng = RandomState ()
25
25
tr = df .sample (frac = 0.7 , random_state = rng )
26
26
tst = df .loc [~ df .index .isin (tr .index )]
27
+ print ("spliting original file to train/valid set..." )
27
28
tr .to_csv (path_train , index = False )
28
29
tst .to_csv (path_valid , index = False )
29
30
@@ -50,8 +51,8 @@ def clean_str(string):
50
51
string = re .sub (r"\s{2,}" , " " , string )
51
52
return string .strip ()
52
53
53
- def create_tabular_dataset (lang = 'en' , pretrained_emb = 'glove.6B.300d' ,
54
- path_train , path_valid ):
54
+ def create_tabular_dataset (path_train , path_valid ,
55
+ lang = 'en' , pretrained_emb = 'glove.6B.300d' ):
55
56
56
57
spacy_en = spacy .load ('en' , disable = ['tagger' , 'parser' , 'ner' , 'textcat'
57
58
'entity_ruler' , 'sentencizer' ,
@@ -72,7 +73,7 @@ def tokenizer(text):
72
73
print ('Creating tabular datasets...' )
73
74
train_datafield = [('text' , TEXT ), ('label' , LABEL )]
74
75
tabular_train = TabularDataset (path = path_train ,
75
- format = path_valid ,
76
+ format = 'csv' ,
76
77
skip_header = True ,
77
78
fields = train_datafield )
78
79
@@ -81,11 +82,11 @@ def tokenizer(text):
81
82
tabular_valid = TabularDataset (path = path_valid ,
82
83
format = 'csv' ,
83
84
skip_header = True ,
84
- fields = test_datafield )
85
+ fields = valid_datafield )
85
86
86
87
print ('Building vocaulary...' )
87
- TEXT .build_vocab (train , vectors = pretrained_emb )
88
- LABEL .build_vocab (train )
88
+ TEXT .build_vocab (tabular_train , vectors = pretrained_emb )
89
+ LABEL .build_vocab (tabular_train )
89
90
90
91
91
92
return tabular_train , tabular_valid , TEXT .vocab
@@ -95,14 +96,14 @@ def create_data_iterator(tr_batch_size, val_batch_size,tabular_train,
95
96
#Create the Iterator for datasets (Iterator works like dataloader)
96
97
97
98
train_iter = Iterator (
98
- train ,
99
+ tabular_train ,
99
100
batch_size = tr_batch_size ,
100
101
device = d ,
101
102
sort_within_batch = False ,
102
103
repeat = False )
103
104
104
105
valid_iter = Iterator (
105
- test ,
106
+ tabular_valid ,
106
107
batch_size = val_batch_size ,
107
108
device = d ,
108
109
sort_within_batch = False ,
0 commit comments