Commit f029709
Merged with predict branch.
2 parents: 6e6e4b6 + 41f4f08

20 files changed (+165 -72 lines)
9 files renamed without changes.

root/dataset/api.py renamed to dataset/api.py (+10 -9)
@@ -1,19 +1,18 @@
-import itertools
 from collections import namedtuple
 
 import numpy as np
 from keras.utils import to_categorical
-from keras_preprocessing.sequence import pad_sequences
-from root.constants import NO_ENTITY_TOKEN, MAX_LEN, PAD, MAX_LEN_CHAR
+from keras_preprocessing.sequence import pad_sequences
+from constants import NO_ENTITY_TOKEN, MAX_LEN, PAD, MAX_LEN_CHAR
 from .data_processor import numericalize
 from .vocab import TextVocab, LabelVocab, PosVocab, CharacterVocab
 
 
 def load_dataset():
     # load examples
-    train_examples = load_examples('../dataset/raw/train.txt')
-    val_examples = load_examples('../dataset/raw/valid.txt')
-    test_examples = load_examples('../dataset/raw/test.txt')
+    train_examples = load_examples('data/raw/train.txt')
+    val_examples = load_examples('data/raw/valid.txt')
+    test_examples = load_examples('data/raw/test.txt')
 
     # build vocabularies
     text_vocab = TextVocab.build(list(map(lambda e: e.sentence, train_examples)))
@@ -32,7 +31,7 @@ def load_examples(file_path):
     """
     Loads sentences from file in CoNLL 2003 format.
 
-    :param file_path: Path to file with CoNLL data.
+    :param file_path: Path to file with CoNLL dataset.
     :return: list(Example)
     """
     examples = []
@@ -57,8 +56,10 @@ def load_examples(file_path):
 
             sentence.append(parts[0])
 
-            if parts[1] in ['$', '"', '(', ')', "''", '.', ':', ',']:
-                pos.append('NN')
+            if parts[1] == '(':
+                pos.append('-LRB-')
+            elif parts[1] == ')':
+                pos.append('-RRB-')
             else:
                 pos.append(parts[1])
 
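A note on the POS fix above: the old code collapsed every punctuation tag ('$', '"', '(', ')', etc.) to 'NN'; the new code leaves tags untouched except for raw brackets, which it maps to the Penn Treebank forms. That matters because predict.py (added below) tags user input with spaCy, which emits '-LRB-'/'-RRB-' for parentheses, so the training-time tags must follow the same convention or bracket tokens fall back to the unknown index in PosVocab. A minimal sketch of the same mapping; the helper and dict names here are mine, not the repository's:

# Hypothetical helper equivalent to the if/elif chain in load_examples.
_BRACKET_TAGS = {'(': '-LRB-', ')': '-RRB-'}  # Penn Treebank bracket tags

def normalize_pos(tag):
    # leave every tag as-is except raw brackets
    return _BRACKET_TAGS.get(tag, tag)

assert normalize_pos('(') == '-LRB-'
assert normalize_pos('NN') == 'NN'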

root/dataset/data_processor.py renamed to dataset/data_processor.py (-1)
@@ -1,4 +1,3 @@
-from keras import preprocessing
 from keras_preprocessing.sequence import pad_sequences
 
 
root/dataset/vocab.py renamed to dataset/vocab.py (+8 -4)
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict, Counter
-from root.constants import PAD, UNK
+from constants import PAD, UNK
 
 
 class Vocab(ABC):
@@ -34,12 +34,16 @@ def build(sentences, max_size=None):
 
         vocab = TextVocab()
         vocab._itos = [PAD, UNK] + list(map(lambda t: t[0], words_and_freqs))
-        vocab._stoi = defaultdict(lambda: 1)  # index of UNK token
+        vocab._stoi = defaultdict(_unk_token_idx)  # index of UNK token
         vocab.stoi.update({k: v for v, k in enumerate(vocab.itos)})
 
         return vocab
 
 
+def _unk_token_idx():
+    return 1
+
+
 class LabelVocab(Vocab):
     @staticmethod
     def build(sentences):
@@ -63,7 +67,7 @@ def build(sentences):
 
         vocab = PosVocab()
         vocab._itos = [PAD] + list(unique_pos)
-        vocab._stoi = defaultdict(lambda: 1)
+        vocab._stoi = defaultdict(_unk_token_idx)
         vocab.stoi.update({k: v for v, k in enumerate(vocab.itos)})
 
         return vocab
@@ -76,7 +80,7 @@ def build(words):
         chars = list(map(lambda c: c, " 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,-_()[]{}!?:;#'\"/\\%$`&=*+@^~|"))
 
         vocab._itos = [PAD, UNK] + chars
-        vocab._stoi = defaultdict(lambda: 2)
+        vocab._stoi = defaultdict(_unk_token_idx)
         vocab.stoi.update({k: v for v, k in enumerate(vocab.itos)})
 
         return vocab
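The switch from defaultdict(lambda: 1) to defaultdict(_unk_token_idx) is presumably driven by the new pickling in utils/serialization.py (below): a lambda cannot be pickled, while a named module-level function can, so vocabularies built this way survive save_object/load_object. As a side effect it also corrects CharacterVocab, whose old default of 2 pointed at the space character instead of UNK at index 1. A quick self-contained sketch of the difference:

import pickle
from collections import defaultdict

def _unk_token_idx():
    return 1  # index of UNK in itos = [PAD, UNK, ...]

picklable = defaultdict(_unk_token_idx)
pickle.dumps(picklable)  # works: the factory is referenced by name

unpicklable = defaultdict(lambda: 1)
try:
    pickle.dumps(unpicklable)
except (pickle.PicklingError, AttributeError) as err:
    print('lambda-backed defaultdict cannot be pickled:', err)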

embedding/glove.py (+1 -1)
@@ -1,7 +1,7 @@
 import sys
 import numpy as np
 
-GLOVE_DIR = '../embedding/glove.6B.100d.txt'
+GLOVE_DIR = 'embedding/glove.6B.100d.txt'
 
 
 def get_pretrained_glove(num_words, text_vocab):
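The dropped '../' prefix follows the same pattern as the other path changes in this commit: with the root/ package flattened away, paths now resolve relative to the repository root, so the scripts are presumably meant to be launched from the top-level directory (e.g. python train.py) rather than from inside a subpackage.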

root/model.py renamed to model.py (+8 -10)
@@ -7,12 +7,12 @@
 from keras.utils.vis_utils import plot_model
 from keras.callbacks import TensorBoard
 
-from root.constants import MAX_LEN, MAX_LEN_CHAR
+from constants import MAX_LEN, MAX_LEN_CHAR
 
 
 class NeuralNetwork(object):
 
-    def __init__(self, num_words, num_entities, num_pos, num_chars, train, test, validation):
+    def __init__(self, save_path, num_words, num_entities, num_pos, num_chars, train, test, validation):
         self.num_words = num_words
         self.num_entities = num_entities
         self.num_pos = num_pos
@@ -26,6 +26,7 @@ def __init__(self, num_words, num_entities, num_pos, num_chars, train, test, val
         self.train_pos = train.pos
         self.test_pos = test.pos
         self.valid_pos = validation.pos
+        self.save_path = save_path
 
         self.train_characters = train.characters
         self.test_characters = test.characters
@@ -56,19 +57,16 @@ def train(self, epochs, embedding=None):
         # Deep Layers
         model = Bidirectional(LSTM(units=100, return_sequences=True, recurrent_dropout=0.1))(x)
         model = Bidirectional(LSTM(units=100, return_sequences=True, recurrent_dropout=0.1))(model)
-        model = Bidirectional(LSTM(units=100, return_sequences=True, recurrent_dropout=0.1))(model)
 
         # Output
         out = TimeDistributed(Dense(self.num_entities, activation="softmax"))(model)
         model = Model(inputs=[txt_input, pos_input, char_in], outputs=[out])
 
         model.compile(optimizer="rmsprop", loss='categorical_crossentropy', metrics=['accuracy'])
 
-        plot_model(model, to_file='../models/ner_model_image.png')
+        plot_model(model, to_file=self.save_path + 'ner_model_image.png')
         print(model.summary())
 
-        model.compile(optimizer="rmsprop", metrics=['accuracy'], loss='categorical_crossentropy')
-
         dir = create_dir()
 
         tensorboard_callback = TensorBoard(log_dir=dir, histogram_freq=0, write_graph=True, write_images=True)
@@ -82,7 +80,7 @@ def train(self, epochs, embedding=None):
                               np.array(self.Y_validation)),
             callbacks=[tensorboard_callback], verbose=1)
 
-        model.save("../models/ner_model")
+        model.save(self.save_path + 'ner_model')
 
         test_eval = model.evaluate(
             [self.X_test, self.test_pos, np.array(self.test_characters).reshape((len(self.test_characters), MAX_LEN, MAX_LEN_CHAR))],
@@ -94,16 +92,16 @@ def train(self, epochs, embedding=None):
 
 
 def create_dir():
-    runs = ([x[0] for x in os.walk("../results/logs")])
+    runs = ([x[0] for x in os.walk("results/logs")])
     runs = [x for x in runs if "run" in x]
     runs = list(map(int, re.findall(r'\d+', "".join(runs))))
     runs.sort()
     if len(runs) == 0:
-        return "../results/logs/run1"
+        return "results/logs/run1"
 
     dir_idx = runs[-1] + 1
 
-    dir = "../results/logs/run" + str(dir_idx)
+    dir = "results/logs/run" + str(dir_idx)
 
     if not os.path.exists(dir):
         os.makedirs(dir)
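Besides rerouting paths through save_path, this diff removes the third stacked BiLSTM layer and the duplicate model.compile call (the model was compiled twice with identical settings; only the compile in effect before fit matters, so the second call was redundant). Callers now thread a per-run directory through the constructor; a sketch of the wiring, mirroring train.py below, with 'models/example-run/' as a hypothetical directory:

from dataset.api import load_dataset
from model import NeuralNetwork

text_vocab, labels_vocab, pos_vocab, char_vocab, train, val, test = load_dataset()
nn = NeuralNetwork('models/example-run/',  # hypothetical run directory
                   len(text_vocab.itos), len(labels_vocab.itos),
                   len(pos_vocab.itos), len(char_vocab.itos),
                   train, test, val)
model, history = nn.train(epochs=1)  # writes ner_model and ner_model_image.png under the run directory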

predict.py (new file, +63)
@@ -0,0 +1,63 @@
+import os
+import argparse
+import spacy
+import numpy as np
+from keras.models import load_model
+from dataset.data_processor import numericalize
+from utils.serialization import load_object
+from constants import NO_ENTITY_TOKEN, MAX_LEN_CHAR
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Script for using NER model')
+    parser.add_argument('-p', '--path', help='Path to model and vocabulary directory.')
+
+    args = parser.parse_args()
+    # add path separator (/) at the end if needed
+    args.path = args.path if args.path[-1] == os.path.sep else args.path + os.path.sep
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    text_vocab = load_object(args.path + 'text_vocab')
+    pos_vocab = load_object(args.path + 'pos_vocab')
+    char_vocab = load_object(args.path + 'char_vocab')
+    labels_vocab = load_object(args.path + 'labels_vocab')
+    model = load_model(args.path + 'ner_model')
+    nlp = spacy.load('en')
+
+    while True:
+        user_input = input('Input sentence: ').strip()
+        if not user_input:
+            continue
+        if user_input == 'end':
+            break
+
+        # tokenize user input
+        doc = nlp(user_input)
+        text = [token.text for token in doc]
+        pos = [token.tag_ for token in doc]
+        chars = numericalize(char_vocab, [[c for c in token.text] for token in doc], NO_ENTITY_TOKEN, maxlen=MAX_LEN_CHAR)
+        chars = np.array(chars)[np.newaxis, :, :]
+
+        print(chars)
+
+        # get model output
+        # pad token is irrelevant here because we are numericalizing just one sentence (it won't be padded)
+        text = np.array(numericalize(text_vocab, [text], NO_ENTITY_TOKEN))
+        pos = np.array(numericalize(pos_vocab, [pos], NO_ENTITY_TOKEN))
+
+        out = model.predict([text, pos, chars]).squeeze()
+        predicted_labels = [labels_vocab.itos[label_idx] for label_idx in np.argmax(out, axis=1).tolist()]
+
+        # print result
+        for token, label in zip([token.text for token in doc], predicted_labels):
+            print("%s %s" % (token, label))
+        print()
+
+
+if __name__ == '__main__':
+    main()
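The new script is interactive: it loads the four pickled vocabularies and the trained model from the directory passed via -p, then reads sentences in a loop until the literal input end. An illustrative session, with a made-up model directory and example labels in the CoNLL 2003 tagging scheme (the intermediate print(chars) debug dump is omitted):

$ python predict.py -p models/2019-01-01-12:00
Input sentence: John lives in London
John B-PER
lives O
in O
London B-LOC
Input sentence: end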

root/main.py (-41)
This file was deleted.

root/test_model.py renamed to test_model.py (+5 -5)
@@ -5,13 +5,13 @@
 from keras.models import load_model
 from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
 
-from root.constants import MAX_LEN
+from constants import MAX_LEN
 from utils.classification_report import classification_report
 from utils.plot_confusion_matrix_util import plot_confusion_matrix
 
 
-def test_model(test, text_vocab, labels_vocab):
-    model = load_model('../models/ner_model')
+def test_model(model_path, test, text_vocab, labels_vocab):
+    model = load_model(model_path + 'ner_model')
 
     predicted_values = np.argmax(model.predict([test.X, test.pos, np.array(test.characters).reshape((len(test.characters), MAX_LEN, 10))]),
                                  axis=-1)
@@ -41,13 +41,13 @@ def test_model(test, text_vocab, labels_vocab):
     print(report)
 
     # plot_classification_report(report)
-    # plt.savefig('../results/classification_report.png', dpi=200, format='png', bbox_inches='tight')
+    # plt.savefig('results/classification_report.png', dpi=200, format='png', bbox_inches='tight')
     # plt.close()
 
     # Confusion Matrix
     cnf_matrix = confusion_matrix(true_values, predicted_values)
     np.set_printoptions(precision=2)
     # TODO fix classes
     plot_confusion_matrix(cnf_matrix, classes=list(labels_vocab.stoi.keys()), normalize=True, title='Normalized confusion matrix')
-    plt.savefig('../results/confusion_matrix.png', dpi=200, format='png', bbox_inches='tight')
+    plt.savefig('results/confusion_matrix.png', dpi=200, format='png', bbox_inches='tight')
     plt.close()
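The knock-on effect of the model.py change: test_model no longer hard-codes '../models/ner_model', so callers must pass the same run directory that training used (train.py below calls test_model(save_path, ...)). Note also that the reshape here still hard-codes 10 for the character dimension where model.py uses MAX_LEN_CHAR; presumably the two are meant to agree.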

train.py (new file, +50)
@@ -0,0 +1,50 @@
+import matplotlib.pyplot as plt
+
+from embedding.glove import get_pretrained_glove
+from dataset.api import load_dataset
+from model import NeuralNetwork
+from test_model import test_model
+from datetime import datetime
+from utils.serialization import save_object
+
+text_vocab, labels_vocab, pos_vocab, character_vocab, train, val, test = load_dataset()
+
+num_words = len(text_vocab.itos)
+num_entities = len(labels_vocab.itos)
+num_pos = len(pos_vocab.itos)
+num_chars = len(character_vocab.itos)
+
+# save vocabulary
+save_path = 'models/' + datetime.now().strftime("%Y-%m-%d-%H:%M") + '/'
+save_object(text_vocab, save_path + 'text_vocab')
+save_object(pos_vocab, save_path + 'pos_vocab')
+save_object(character_vocab, save_path + 'char_vocab')
+save_object(labels_vocab, save_path + 'labels_vocab')
+
+nn = NeuralNetwork(save_path, num_words, num_entities, num_pos, num_chars, train, test, val)
+
+model, history = nn.train(epochs=3, embedding=get_pretrained_glove(num_words, text_vocab))
+
+print(history.history.keys())
+
+test_model(save_path, test, text_vocab, labels_vocab)
+
+# Plot accuracy
+plt.plot(history.history['acc'])
+plt.plot(history.history['val_acc'])
+plt.title('Model Accuracy')
+plt.ylabel('Accuracy')
+plt.xlabel('Epoch')
+plt.legend(['train', 'validation'], loc='lower right')
+plt.savefig('results/model_accuracy.png', dpi=200, format='png', bbox_inches='tight')
+plt.close()
+
+# Plot loss
+plt.plot(history.history['loss'])
+plt.plot(history.history['val_loss'])
+plt.title('Model loss')
+plt.ylabel('Loss')
+plt.xlabel('Epoch')
+plt.legend(['train', 'validation'], loc='upper right')
+plt.savefig('results/model_loss.png', dpi=200, format='png', bbox_inches='tight')
+plt.close()
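After a run, the timestamped directory created here should contain everything predict.py needs, roughly (timestamp illustrative):

models/2019-01-01-12:00/
    text_vocab
    pos_vocab
    char_vocab
    labels_vocab
    ner_model
    ner_model_image.png

One caveat: the "%Y-%m-%d-%H:%M" format puts a colon in the directory name, which works on Linux and macOS but is not a valid path character on Windows.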

utils/classification_report.py (+1 -1)
@@ -24,7 +24,7 @@ def classification_report(y_true, y_pred, labels=None, target_names=None,
     digits : int
         Number of digits for formatting output floating point values
     average : string, ['weighted' (default), 'binary', 'micro', 'macro']
-        Determines the type of averaging performed on the data, after reporting the individual results per class:
+        Determines the type of averaging performed on the dataset, after reporting the individual results per class:
         ``'binary'``:
             Only report results for the class specified by ``pos_label``.
             This is applicable only if targets (``y_{true,pred}``) are binary.

utils/serialization.py (new file, +19)
@@ -0,0 +1,19 @@
+import os
+import pickle
+
+
+def ensure_dir_exists(path):
+    if not os.path.isdir(path):
+        os.makedirs(path)
+
+
+def save_object(obj, path):
+    ensure_dir_exists(os.path.dirname(path))
+    with open(path, 'wb') as fd:
+        pickle.dump(obj, fd)
+
+
+def load_object(path):
+    with open(path, 'rb') as fd:
+        obj = pickle.load(fd)
+    return obj
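A short usage sketch of the new helpers (the path is illustrative): save_object creates missing parent directories via ensure_dir_exists, which is what lets train.py write straight into a not-yet-existing timestamped folder.

from utils.serialization import save_object, load_object

vocab = {'the': 2, 'cat': 3}  # any picklable object
save_object(vocab, 'models/example-run/text_vocab')  # parent dirs created on demand
assert load_object('models/example-run/text_vocab') == vocab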
