Skip to content

Commit 90d5422

Browse files
File dialog support added with PR #3
1 parent bad4f9e commit 90d5422

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

classify.py

+29-27
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import tkinter as tk
55
from tkinter import filedialog
6+
67
# Disable tensorflow compilation warnings
78
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
89
import tensorflow as tf
@@ -14,31 +15,32 @@
1415

1516
image_path = filedialog.askopenfilename()
1617

17-
18-
# Read the image_data
19-
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
20-
21-
# Loads label file, strips off carriage return
22-
label_lines = [line.rstrip() for line
23-
in tf.gfile.GFile("tf_files/retrained_labels.txt")]
24-
25-
# Unpersists graph from file
26-
with tf.gfile.FastGFile("tf_files/retrained_graph.pb", 'rb') as f:
27-
graph_def = tf.GraphDef()
28-
graph_def.ParseFromString(f.read())
29-
_ = tf.import_graph_def(graph_def, name='')
30-
31-
with tf.Session() as sess:
32-
# Feed the image_data as input to the graph and get first prediction
33-
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
34-
35-
predictions = sess.run(softmax_tensor, \
36-
{'DecodeJpeg/contents:0': image_data})
37-
38-
# Sort to show labels of first prediction in order of confidence
39-
top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
18+
if image_path:
4019

41-
for node_id in top_k:
42-
human_string = label_lines[node_id]
43-
score = predictions[0][node_id]
44-
print('%s (score = %.5f)' % (human_string, score))
20+
# Read the image_data
21+
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
22+
23+
# Loads label file, strips off carriage return
24+
label_lines = [line.rstrip() for line
25+
in tf.gfile.GFile("tf_files/retrained_labels.txt")]
26+
27+
# Unpersists graph from file
28+
with tf.gfile.FastGFile("tf_files/retrained_graph.pb", 'rb') as f:
29+
graph_def = tf.GraphDef()
30+
graph_def.ParseFromString(f.read())
31+
_ = tf.import_graph_def(graph_def, name='')
32+
33+
with tf.Session() as sess:
34+
# Feed the image_data as input to the graph and get first prediction
35+
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
36+
37+
predictions = sess.run(softmax_tensor, \
38+
{'DecodeJpeg/contents:0': image_data})
39+
40+
# Sort to show labels of first prediction in order of confidence
41+
top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
42+
43+
for node_id in top_k:
44+
human_string = label_lines[node_id]
45+
score = predictions[0][node_id]
46+
print('%s (score = %.5f)' % (human_string, score))

0 commit comments

Comments
 (0)