Skip to content

New function for normalizing gradient maps through min-max technique. #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tf_explain/core/gradients_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class GradientsInputs(VanillaGradients):
"""

@staticmethod
@tf.function
def compute_gradients(images, model, class_index):
"""
        Compute gradients weighted by input values for target class.
Expand Down
18 changes: 11 additions & 7 deletions tf_explain/core/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tensorflow as tf

from tf_explain.utils.display import grid_display
from tf_explain.utils.image import transform_to_normalized_grayscale
from tf_explain.utils.image import transform_to_normalized_grayscale, normalize_min_max
from tf_explain.utils.saver import save_grayscale


Expand All @@ -17,7 +17,7 @@ class IntegratedGradients:
Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
"""

def explain(self, validation_data, model, class_index, n_steps=10):
def explain(self, validation_data, model, class_index, n_steps=10, norm = "std"):
"""
Compute Integrated Gradients for a specific class index

Expand All @@ -27,6 +27,7 @@ def explain(self, validation_data, model, class_index, n_steps=10):
model (tf.keras.Model): tf.keras model to inspect
class_index (int): Index of targeted class
n_steps (int): Number of steps in the path
norm (str): Normalization technique. Can be chosen from *std* and *min_max*.

Returns:
np.ndarray: Grid of all the integrated gradients
Expand All @@ -41,16 +42,19 @@ def explain(self, validation_data, model, class_index, n_steps=10):
interpolated_images, model, class_index, n_steps
)

grayscale_integrated_gradients = transform_to_normalized_grayscale(
tf.abs(integrated_gradients)
).numpy()
if not norm in ["std", "min_max"]:
raise KeyError("Normalization method can only be chosen from 'std' and 'min_max'.")

grid = grid_display(grayscale_integrated_gradients)
elif norm == "std":
grayscale_integrated_gradients = transform_to_normalized_grayscale(tf.abs(integrated_gradients)).numpy()
grid = grid_display(grayscale_integrated_gradients)

else: # min_max
grid = normalize_min_max(integrated_gradients).numpy()

return grid

@staticmethod
@tf.function
def get_integrated_gradients(interpolated_images, model, class_index, n_steps):
"""
Perform backpropagation to compute integrated gradients.
Expand Down
17 changes: 11 additions & 6 deletions tf_explain/core/smoothgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tensorflow as tf

from tf_explain.utils.display import grid_display
from tf_explain.utils.image import transform_to_normalized_grayscale
from tf_explain.utils.image import transform_to_normalized_grayscale, normalize_min_max
from tf_explain.utils.saver import save_grayscale


Expand All @@ -18,7 +18,7 @@ class SmoothGrad:
Paper: [SmoothGrad: removing noise by adding noise](https://arxiv.org/abs/1706.03825)
"""

def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0):
def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0, norm = "std"):
"""
Compute SmoothGrad for a specific class index

Expand All @@ -29,6 +29,7 @@ def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0)
class_index (int): Index of targeted class
num_samples (int): Number of noisy samples to generate for each input image
noise (float): Standard deviation for noise normal distribution
norm (str): Normalization technique. Can be chosen from *std* and *min_max*. Defaults to "std".

Returns:
np.ndarray: Grid of all the smoothed gradients
Expand All @@ -41,11 +42,15 @@ def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0)
noisy_images, model, class_index, num_samples
)

grayscale_gradients = transform_to_normalized_grayscale(
tf.abs(smoothed_gradients)
).numpy()
if not norm in ["std", "min_max"]:
raise KeyError("Normalization method can only be chosen from 'std' and 'min_max'.")

grid = grid_display(grayscale_gradients)
elif norm == "std":
grayscale_integrated_gradients = transform_to_normalized_grayscale(tf.abs(smoothed_gradients)).numpy()
grid = grid_display(grayscale_integrated_gradients)

else: # min_max
grid = normalize_min_max(smoothed_gradients).numpy()

return grid

Expand Down
25 changes: 15 additions & 10 deletions tf_explain/core/vanilla_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tensorflow as tf

from tf_explain.utils.display import grid_display
from tf_explain.utils.image import transform_to_normalized_grayscale
from tf_explain.utils.image import transform_to_normalized_grayscale, normalize_min_max
from tf_explain.utils.saver import save_grayscale


Expand Down Expand Up @@ -34,7 +34,7 @@ class VanillaGradients:
Models and Saliency Maps](https://arxiv.org/abs/1312.6034)
"""

def explain(self, validation_data, model, class_index):
def explain(self, validation_data, model, class_index, norm = "std"):
"""
Perform gradients backpropagation for a given input

Expand All @@ -47,12 +47,13 @@ def explain(self, validation_data, model, class_index):
gradient calculation to bypass the final activation and calculate
the gradient of the score instead.
class_index (int): Index of targeted class
norm (str): Normalization technique. Can be chosen from *std* and *min_max*. Defaults to *std*.

Returns:
numpy.ndarray: Grid of all the gradients
"""
score_model = self.get_score_model(model)
return self.explain_score_model(validation_data, score_model, class_index)
return self.explain_score_model(validation_data, score_model, class_index, norm)

def get_score_model(self, model):
"""
Expand Down Expand Up @@ -86,7 +87,7 @@ def _is_activation_layer(self, layer):
"""
return isinstance(layer, ACTIVATION_LAYER_CLASSES)

def explain_score_model(self, validation_data, score_model, class_index):
def explain_score_model(self, validation_data, score_model, class_index, norm):
"""
Perform gradients backpropagation for a given input

Expand All @@ -96,24 +97,28 @@ def explain_score_model(self, validation_data, score_model, class_index):
score_model (tf.keras.Model): tf.keras model to inspect. The last layer
should not have any activation function.
class_index (int): Index of targeted class
norm (str): Normalization technique. Can be chosen from *std* and *min_max*.

Returns:
numpy.ndarray: Grid of all the gradients
"""
images, _ = validation_data

images, _ = validation_data
gradients = self.compute_gradients(images, score_model, class_index)

grayscale_gradients = transform_to_normalized_grayscale(
tf.abs(gradients)
).numpy()
if not norm in ["std", "min_max"]:
raise KeyError("Normalization method can only be chosen from 'std' and 'min_max'.")

elif norm == "std":
grayscale_gradients = transform_to_normalized_grayscale(tf.abs(gradients)).numpy()
grid = grid_display(grayscale_gradients)

grid = grid_display(grayscale_gradients)
else: # min_max
grid = normalize_min_max(gradients).numpy()

return grid

@staticmethod
@tf.function
def compute_gradients(images, model, class_index):
"""
Compute gradients for target class.
Expand Down
21 changes: 21 additions & 0 deletions tf_explain/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,24 @@ def transform_to_normalized_grayscale(tensor):
)

return normalized_tensor

def normalize_min_max(tensor):
    """
    Collapse a 4D gradient tensor into a single min-max-normalized grayscale image.

    Absolute values are reduced over the channel axis (taking the per-pixel
    maximum), then rescaled to the [0, 255] uint8 range using the global
    min/max computed over the whole batch. Note: only the FIRST image of the
    batch is returned; remaining batch elements are discarded.

    Args:
        tensor (tf.Tensor): 4D-Tensor with shape (batch_size, H, W, channels)

    Returns:
        tf.Tensor: 2D uint8 Tensor with shape (H, W) — the normalized first
            image of the batch
    """
    normalized_tensor = tf.math.abs(tensor)
    # Keep the strongest channel response for each pixel
    normalized_tensor = tf.math.reduce_max(normalized_tensor, axis=-1)

    # Rescale to [0, 1]; the epsilon guards against division by zero
    # when the map is constant (arr_max == arr_min)
    arr_min = tf.math.reduce_min(normalized_tensor)
    arr_max = tf.math.reduce_max(normalized_tensor)
    normalized_tensor = (normalized_tensor - arr_min) / (arr_max - arr_min + 1e-16)

    # Scale to [0, 255] and keep only the first image of the batch
    normalized_tensor = tf.cast(255 * normalized_tensor, tf.uint8)[0]

    return normalized_tensor