diff --git a/tf_keras/utils/losses_utils.py b/tf_keras/utils/losses_utils.py index 49d5e1cc7..1a7d100f2 100644 --- a/tf_keras/utils/losses_utils.py +++ b/tf_keras/utils/losses_utils.py @@ -195,7 +195,7 @@ def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None): y_true_rank = y_true_shape.ndims if (y_true_rank is not None) and (y_pred_rank is not None): # Use static rank for `y_true` and `y_pred`. - if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1: + if (y_pred_rank - y_true_rank == 1) or y_pred_shape[-1] == 1: y_true, y_pred = remove_squeezable_dimensions(y_true, y_pred) else: # Use dynamic rank.