diff --git a/examples/nlp/multi_label_classification.py b/examples/nlp/multi_label_classification.py index 14c944b657..d2de991c6c 100644 --- a/examples/nlp/multi_label_classification.py +++ b/examples/nlp/multi_label_classification.py @@ -30,8 +30,9 @@ ## Imports """ -from tensorflow.keras import layers -from tensorflow import keras +import keras +from keras import layers + import tensorflow as tf from sklearn.model_selection import train_test_split @@ -145,7 +146,7 @@ """ terms = tf.ragged.constant(train_df["terms"].values) -lookup = tf.keras.layers.StringLookup(output_mode="multi_hot") +lookup = layers.StringLookup(output_mode="multi_hot") lookup.adapt(terms) vocab = lookup.get_vocabulary()