Skip to content

Commit e6ff002

Browse files
refactor: naming & filtering
1 parent 8eecb7f commit e6ff002

File tree

2 files changed

+17
-31
lines changed

2 files changed

+17
-31
lines changed

main.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
from model.network import LeNet5
22
from saliency.vanilla_gradient import save_vanilla_gradient
33
from model.data import mnist_train_test_sets
4+
import numpy as np
45

56
# Get MNIST dataset, preprocessed
67
train_images, train_labels, test_images, test_labels = mnist_train_test_sets()
78

89
# Load net and 98% acc weights
910
net = LeNet5(weights_path="15epoch_weights.pkl")
1011

11-
# Uncomment if you want to train or test from scratch
12+
# Uncomment if you want to train or test
1213
# net.train(training_data=train_images, training_labels=train_labels,
1314
# batch_size=32, epochs=3, weights_path='weights.pkl')
1415
# net.test(test_images, test_labels)
1516

1617
# Uncomment if you want to filter by class
17-
# target_class = 5
18-
# target_indexes = [i for i in range(len(labels))
19-
# if np.argmax(labels[i]) == target_class]
20-
# target_images = [data[index] for index in target_indexes]
21-
# target_labels = [labels[index] for index in target_indexes]
18+
# target_image_class = 7
19+
# target_image_indexes = [i for i in range(len(test_labels))
20+
# if np.argmax(test_labels[i]) == target_image_class]
21+
# target_images = [test_images[index] for index in target_image_indexes]
22+
# target_labels = [test_labels[index] for index in target_image_indexes]
2223

2324
# Generate saliency maps for the first 10 images
2425
target_images = train_images[:10]

saliency/vanilla_gradient.py

+10-25
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def save_vanilla_gradient(network, data, labels):
1212

1313
# Create a saliency map for each data point
1414
for i, image in enumerate(data):
15-
# Put input into layers
15+
# Forward pass on image
1616
output = image
1717
for l in range(len(network.layers)):
1818
output = network.layers[l].forward(output)
@@ -23,9 +23,9 @@ def save_vanilla_gradient(network, data, labels):
2323
for l in range(len(network.layers)-1, -1, -1):
2424
dout = network.layers[l].backward(dy)
2525
dy = dout
26+
raw_saliency_map = dout
2627

2728
# Process saliency map
28-
raw_saliency_map = dout
2929
trimmed_saliency_map = trim_map(raw_saliency_map)
3030
saliency_map = normalize_array(trimmed_saliency_map)
3131

@@ -37,7 +37,7 @@ def save_vanilla_gradient(network, data, labels):
3737
# Export saliency map renderings
3838
filename = "index-{0}_class-{1}_vanilla".format(
3939
str(i), str(np.argmax(label_one_hot)))
40-
save_gradient_overlay_images(image, saliency_map, filename)
40+
save_gradient_images(image, saliency_map, filename)
4141

4242
print("Saved Vanilla Gradient image to results folder")
4343

@@ -63,21 +63,6 @@ def normalize_array(arr):
6363
return arr
6464

6565

66-
def save_gradient_images(gradient, file_name):
67-
"""
68-
Exports the original gradient image
69-
70-
Args:
71-
gradient (np arr): Numpy array of the gradient with shape (3, 224, 224)
72-
file_name (str): File name to be exported
73-
"""
74-
# Normalize
75-
gradient = normalize_array(gradient)
76-
# Save image
77-
path_to_file = os.path.join(RESULTS_FOLDER, file_name + '.jpg')
78-
save_image(gradient, path_to_file)
79-
80-
8166
def save_image(im, path):
8267
"""
8368
Saves a numpy matrix or PIL image as an image
@@ -120,17 +105,17 @@ def format_np_output(np_arr):
120105
return np_arr
121106

122107

123-
def save_gradient_overlay_images(org_img, saliency_map, file_name):
108+
def save_gradient_images(org_img, saliency_map, file_name):
124109
"""
125110
Saves saliency map and overlay on the original image
126111
127112
Args:
128113
org_img (PIL img): Original image
129-
activation_map (numpy arr): Activation map (grayscale) 0-255
114+
saliency_map (numpy arr): Saliency map (grayscale) 0-255
130115
file_name (str): File name of the exported image
131116
"""
132117

133-
# Grayscale activation map
118+
# Grayscale saliency map
134119
heatmap, heatmap_on_image = apply_colormap_on_image(
135120
org_img, saliency_map, 'RdBu')
136121

@@ -149,20 +134,20 @@ def save_gradient_overlay_images(org_img, saliency_map, file_name):
149134

150135
# Save grayscale heatmap
151136
# path_to_file = os.path.join(RESULTS_FOLDER, file_name+'_grayscale.png')
152-
# save_image(activation_map, path_to_file)
137+
# save_image(saliency_map, path_to_file)
153138

154139

155-
def apply_colormap_on_image(org_im, activation, colormap_name):
140+
def apply_colormap_on_image(org_im, saliency_map, colormap_name):
156141
"""
157142
Apply heatmap on image
158143
Args:
159144
org_img (PIL img): Original image
160-
activation_map (numpy arr): Activation map (grayscale) 0-255
145+
saliency_map (numpy arr): Saliency map (grayscale) 0-255
161146
colormap_name (str): Name of the colormap
162147
"""
163148
# Get colormap
164149
color_map = mpl_color_map.get_cmap(colormap_name)
165-
no_trans_heatmap = color_map(activation)
150+
no_trans_heatmap = color_map(saliency_map)
166151
# Change alpha channel in colormap to make sure original image is displayed
167152
heatmap = copy.copy(no_trans_heatmap)
168153
heatmap[:, :, 3] = 0.5

0 commit comments

Comments
 (0)