Commit 41f4f08 (1 parent: c323e08)

Fixed predict script. Some refactoring.

5 files changed: +20 −9 lines

dataset/api.py

+4 −2

@@ -55,8 +55,10 @@ def load_examples(file_path):
 
             sentence.append(parts[0])
 
-            if parts[1] in ['$', '"', '(', ')', "''", '.', ':', ',']:
-                pos.append('NN')
+            if parts[1] == '(':
+                pos.append('-LRB-')
+            elif parts[1] == ')':
+                pos.append('-RRB-')
             else:
                 pos.append(parts[1])
 
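The change narrows the old catch-all (which collapsed every punctuation tag, brackets included, to 'NN') down to remapping just the parentheses, presumably so the tags read from the training file line up with what spaCy's Penn-Treebank-style tagger emits at predict time. A quick way to see those tags (assuming the English model that predict.py loads is installed):

import spacy

nlp = spacy.load('en')   # same shortcut predict.py uses
for token in nlp("He said (quietly) hello."):
    print(token.text, token.tag_)   # '(' tags as -LRB-, ')' as -RRB-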

dataset/data_processor.py

−1

@@ -1,4 +1,3 @@
-from keras import preprocessing
 from keras_preprocessing.sequence import pad_sequences
 
 
dataset/vocab.py

+2 −2

@@ -67,7 +67,7 @@ def build(sentences):
 
         vocab = PosVocab()
         vocab._itos = [PAD] + list(unique_pos)
-        vocab._stoi = defaultdict(lambda: 1)
+        vocab._stoi = defaultdict(_unk_token_idx)
         vocab.stoi.update({k: v for v, k in enumerate(vocab.itos)})
 
         return vocab

@@ -80,7 +80,7 @@ def build(words):
 
         chars = list(map(lambda c: c, " 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,-_()[]{}!?:;#'\"/\\%$`&=*+@^~|"))
 
         vocab._itos = [PAD, UNK] + chars
-        vocab._stoi = defaultdict(lambda: 2)
+        vocab._stoi = defaultdict(_unk_token_idx)
         vocab.stoi.update({k: v for v, k in enumerate(vocab.itos)})
 
         return vocab
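Swapping the inline lambdas for a named _unk_token_idx looks like a pickling fix: these vocabs get saved with save_object (see train.py below), and Python's pickle cannot serialize a defaultdict whose default factory is a lambda, while a module-level function pickles fine by name. A minimal reproduction, with _unk_token_idx standing in for the function this commit references:

import pickle
from collections import defaultdict

def _unk_token_idx():
    # module-level functions pickle by qualified name; lambdas have none
    return 1

try:
    pickle.dumps(defaultdict(lambda: 1))
except Exception as e:
    print("lambda factory fails:", e)     # PicklingError

stoi = pickle.loads(pickle.dumps(defaultdict(_unk_token_idx)))
print(stoi["never-seen-token"])           # -> 1, the UNK index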

predict.py

+12 −4

@@ -23,6 +23,8 @@ def main():
     args = parse_args()
 
     text_vocab = load_object(args.path + 'text_vocab')
+    pos_vocab = load_object(args.path + 'pos_vocab')
+    char_vocab = load_object(args.path + 'char_vocab')
     labels_vocab = load_object(args.path + 'labels_vocab')
     model = load_model(args.path + 'ner_model')
     nlp = spacy.load('en')

@@ -36,15 +38,21 @@ def main():
 
         # tokenize user input
         doc = nlp(user_input)
-        user_input_tokenized = [token.text for token in doc]
+        text = [token.text for token in doc]
+        pos = [token.tag_ for token in doc]
+        chars = numericalize(char_vocab, [[c for c in token.text] for token in doc], NO_ENTITY_TOKEN, maxlen=10)
+        chars = np.array(chars)[np.newaxis, :, :]
 
         # get model output
-        model_input = np.array(numericalize(text_vocab, [user_input_tokenized], NO_ENTITY_TOKEN))  # pad token is irrelevant here beacuse we are numericalizing just one sentence (it won't be padded)
-        out = model.predict(model_input).squeeze()
+        # pad token is irrelevant here because we are numericalizing just one sentence (it won't be padded)
+        text = np.array(numericalize(text_vocab, [text], NO_ENTITY_TOKEN))
+        pos = np.array(numericalize(pos_vocab, [pos], NO_ENTITY_TOKEN))
+
+        out = model.predict([text, pos, chars]).squeeze()
         predicted_labels = [labels_vocab.itos[label_idx] for label_idx in np.argmax(out, axis=1).tolist()]
 
         # print result
-        for token, label in zip(user_input_tokenized, predicted_labels):
+        for token, label in zip([token.text for token in doc], predicted_labels):
             print("%s %s" % (token, label))
         print()
 
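The fix in a nutshell: the refactored model now expects three aligned inputs per sentence (word ids, POS-tag ids, character ids), so the script numericalizes all three with their own vocabularies before calling predict. A toy sketch of the character tensor's shape, with made-up indices (the repo's own numericalize and char_vocab do the real work):

import numpy as np

tokens = ["Obama", "visited", "Paris"]
maxlen = 10                                             # same cap as the diff
char_stoi = {c: i + 2 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}

chars = np.zeros((len(tokens), maxlen), dtype="int32")  # 0 = PAD
for t, token in enumerate(tokens):
    for c, ch in enumerate(token[:maxlen]):
        chars[t, c] = char_stoi.get(ch.lower(), 1)      # 1 = UNK

chars = chars[np.newaxis, :, :]   # add the batch axis, as predict.py does
print(chars.shape)                # (1, 3, 10); text and pos are (1, 3)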

train.py

+2

@@ -17,6 +17,8 @@
 # save vocabulary
 save_path = 'models/' + datetime.now().strftime("%Y-%m-%d-%H:%M") + '/'
 save_object(text_vocab, save_path + 'text_vocab')
+save_object(pos_vocab, save_path + 'pos_vocab')
+save_object(character_vocab, save_path + 'char_vocab')
 save_object(labels_vocab, save_path + 'labels_vocab')
 
 nn = NeuralNetwork(save_path, num_words, num_entities, num_pos, num_chars, train, test, val)
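save_object and load_object are not part of this diff; a plausible reading, consistent with the defaultdict fix in vocab.py above, is that they are thin pickle wrappers along these lines (a hypothetical sketch, not the repo's actual code):

import os
import pickle

def save_object(obj, path):
    # hypothetical helper: persist any picklable object to disk
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def load_object(path):
    # hypothetical counterpart used by predict.py
    with open(path, 'rb') as f:
        return pickle.load(f)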
