Skip to content

Handle y_true as class targets for GradCAM #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions tf_explain/core/grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def explain(
self,
validation_data,
model,
class_index,
class_index=None,
layer_name=None,
use_guided_grads=True,
colormap=cv2.COLORMAP_VIRIDIS,
Expand All @@ -38,21 +38,27 @@ def explain(
class_index (int): Index of targeted class
layer_name (str): Targeted layer for GradCAM. If no layer is provided, it is
automatically infered from the model architecture.
use_guided_grads (boolean): Whether to use guided grads or raw gradients
colormap (int): OpenCV Colormap to use for heatmap visualization
image_weight (float): An optional `float` value in range [0,1] indicating the weight of
the input image to be overlaying the calculated attribution maps. Defaults to `0.7`.
use_guided_grads (boolean): Whether to use guided grads or raw gradients

Returns:
numpy.ndarray: Grid of all the GradCAM
"""
images, _ = validation_data
images, targets = validation_data
is_target_categorical = len(np.array(targets, dtype="uint8").shape) == 2
if is_target_categorical:
targets = list(np.argmax(targets, axis=1))

if layer_name is None:
layer_name = self.infer_grad_cam_target_layer(model)

outputs, grads = GradCAM.get_gradients_and_filters(
model, images, layer_name, class_index, use_guided_grads
if class_index is not None:
targets = [class_index] * len(images)

outputs, grads = GradCAM.get_gradients_and_filters_from_targets(
model, images, layer_name, targets, use_guided_grads
)

cams = GradCAM.generate_ponderated_output(outputs, grads)
Expand Down Expand Up @@ -92,8 +98,8 @@ def infer_grad_cam_target_layer(model):

@staticmethod
@tf.function
def get_gradients_and_filters(
model, images, layer_name, class_index, use_guided_grads
def get_gradients_and_filters_from_targets(
model, images, layer_name, targets, use_guided_grads
):
"""
Generate guided gradients and convolutional outputs with an inference.
Expand All @@ -102,7 +108,7 @@ def get_gradients_and_filters(
model (tf.keras.Model): tf.keras model to inspect
images (numpy.ndarray): 4D-Tensor with shape (batch_size, H, W, 3)
layer_name (str): Targeted layer for GradCAM
class_index (int): Index of targeted class
targets (List[int]): List of class targets
use_guided_grads (boolean): Whether to use guided grads or raw gradients

Returns:
Expand All @@ -115,8 +121,9 @@ def get_gradients_and_filters(
with tf.GradientTape() as tape:
inputs = tf.cast(images, tf.float32)
conv_outputs, predictions = grad_model(inputs)
loss = predictions[:, class_index]

loss = tf.gather_nd(
predictions, [[index, target] for index, target in enumerate(targets)]
)
grads = tape.gradient(loss, conv_outputs)

if use_guided_grads:
Expand Down