From d97daf4b303124fd4e177334bd0b097b5daeb4d0 Mon Sep 17 00:00:00 2001 From: Raphael Meudec Date: Sat, 11 Jul 2020 17:07:27 +0200 Subject: [PATCH] Handle y_true --- tf_explain/core/grad_cam.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tf_explain/core/grad_cam.py b/tf_explain/core/grad_cam.py index 7d384e5..03b1bc9 100644 --- a/tf_explain/core/grad_cam.py +++ b/tf_explain/core/grad_cam.py @@ -22,7 +22,7 @@ def explain( self, validation_data, model, - class_index, + class_index=None, layer_name=None, use_guided_grads=True, colormap=cv2.COLORMAP_VIRIDIS, @@ -38,21 +38,27 @@ def explain( class_index (int): Index of targeted class layer_name (str): Targeted layer for GradCAM. If no layer is provided, it is automatically infered from the model architecture. + use_guided_grads (boolean): Whether to use guided grads or raw gradients colormap (int): OpenCV Colormap to use for heatmap visualization image_weight (float): An optional `float` value in range [0,1] indicating the weight of the input image to be overlaying the calculated attribution maps. Defaults to `0.7`. - use_guided_grads (boolean): Whether to use guided grads or raw gradients Returns: numpy.ndarray: Grid of all the GradCAM """ - images, _ = validation_data + images, targets = validation_data + is_target_categorical = len(np.array(targets, dtype="uint8").shape) == 2 + if is_target_categorical: + targets = list(np.argmax(targets, axis=1)) if layer_name is None: layer_name = self.infer_grad_cam_target_layer(model) - outputs, grads = GradCAM.get_gradients_and_filters( - model, images, layer_name, class_index, use_guided_grads + if class_index is not None: + targets = [class_index] * len(images) + + outputs, grads = GradCAM.get_gradients_and_filters_from_targets( + model, images, layer_name, targets, use_guided_grads ) cams = GradCAM.generate_ponderated_output(outputs, grads) @@ -92,8 +98,8 @@ def infer_grad_cam_target_layer(model): @staticmethod @tf.function - def get_gradients_and_filters( - model, images, layer_name, class_index, use_guided_grads + def get_gradients_and_filters_from_targets( + model, images, layer_name, targets, use_guided_grads ): """ Generate guided gradients and convolutional outputs with an inference. @@ -102,7 +108,7 @@ def get_gradients_and_filters( model (tf.keras.Model): tf.keras model to inspect images (numpy.ndarray): 4D-Tensor with shape (batch_size, H, W, 3) layer_name (str): Targeted layer for GradCAM - class_index (int): Index of targeted class + targets (List[int]): List of class targets use_guided_grads (boolean): Whether to use guided grads or raw gradients Returns: @@ -115,8 +121,9 @@ def get_gradients_and_filters( with tf.GradientTape() as tape: inputs = tf.cast(images, tf.float32) conv_outputs, predictions = grad_model(inputs) - loss = predictions[:, class_index] - + loss = tf.gather_nd( + predictions, [[index, target] for index, target in enumerate(targets)] + ) grads = tape.gradient(loss, conv_outputs) if use_guided_grads: