Skip to content

Commit b78cbb9

Browse files
fchollettensorflower-gardener
authored andcommitted
Rewire TensorFlow to rely on tf_keras target.
PiperOrigin-RevId: 559470121
1 parent d16ccaa commit b78cbb9

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2020

2121
try:
22-
from keras.engine import base_layer # pylint: disable=g-import-not-at-top
22+
# OSS case.
23+
import keras # pylint: disable=g-import-not-at-top
24+
if hasattr(keras, 'src'):
25+
# Path as seen in pip packages as of TF/Keras 2.13.
26+
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
27+
else:
28+
from keras.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
2329
except ImportError:
24-
# Path as seen in pip packages as of TF/Keras 2.13.
25-
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top
26-
27-
# TODO(b/139939526): move to public API.
30+
# Internal case.
31+
base_layer = tf._keras_internal.engine.base_layer # pylint: disable=protected-access
2832

2933
layers = tf.keras.layers
3034
layers_compat_v1 = tf.compat.v1.keras.layers

0 commit comments

Comments
 (0)