Skip to content

Commit 6f15f33

Browse files
Replace tensorflow.python.keras with keras. tensorflow.python.keras is an old copy and is deprecated.
PiperOrigin-RevId: 485364304
1 parent 0e08dea commit 6f15f33

25 files changed

+12
-91
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import tensorflow as tf
2323

24-
from tensorflow.python.keras import keras_parameterized
24+
from keras import keras_parameterized
2525
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2626
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2727
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

-34
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from absl.testing import parameterized
2222
import tensorflow as tf
2323

24-
from tensorflow.python.keras import keras_parameterized
2524
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2625
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2726
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
@@ -162,15 +161,13 @@ def _count_clustered_layers(self, model):
162161
count += 1
163162
return count
164163

165-
@keras_parameterized.run_all_keras_modes
166164
def testClusterKerasClusterableLayer(self):
167165
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly."""
168166
wrapped_layer = self._build_clustered_layer_model(
169167
self.keras_clusterable_layer)
170168

171169
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
172170

173-
@keras_parameterized.run_all_keras_modes
174171
def testClusterKerasClusterableLayerWithSparsityPreservation(self):
175172
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly when sparsity preservation is enabled."""
176173
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -180,7 +177,6 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self):
180177

181178
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
182179

183-
@keras_parameterized.run_all_keras_modes
184180
def testClusterKerasNonClusterableLayer(self):
185181
"""Verifies that a built-in keras layer not marked as clusterable is not being clustered."""
186182
wrapped_layer = self._build_clustered_layer_model(
@@ -190,7 +186,6 @@ def testClusterKerasNonClusterableLayer(self):
190186
wrapped_layer)
191187
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
192188

193-
@keras_parameterized.run_all_keras_modes
194189
def testDepthwiseConv2DLayerNonClusterable(self):
195190
"""Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196191
wrapped_layer = self._build_clustered_layer_model(
@@ -200,7 +195,6 @@ def testDepthwiseConv2DLayerNonClusterable(self):
200195
wrapped_layer)
201196
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
202197

203-
@keras_parameterized.run_all_keras_modes
204198
def testDenseLayer(self):
205199
"""Verifies that we can cluster a Dense layer."""
206200
input_shape = (28, 1)
@@ -214,7 +208,6 @@ def testDenseLayer(self):
214208
self.assertEqual([1, 10],
215209
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
216210

217-
@keras_parameterized.run_all_keras_modes
218211
def testConv1DLayer(self):
219212
"""Verifies that we can cluster a Conv1D layer."""
220213
input_shape = (28, 1)
@@ -227,7 +220,6 @@ def testConv1DLayer(self):
227220
self.assertEqual([5, 1, 3],
228221
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
229222

230-
@keras_parameterized.run_all_keras_modes
231223
def testConv1DTransposeLayer(self):
232224
"""Verifies that we can cluster a Conv1DTranspose layer."""
233225
input_shape = (28, 1)
@@ -240,7 +232,6 @@ def testConv1DTransposeLayer(self):
240232
self.assertEqual([5, 3, 1],
241233
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
242234

243-
@keras_parameterized.run_all_keras_modes
244235
def testConv2DLayer(self):
245236
"""Verifies that we can cluster a Conv2D layer."""
246237
input_shape = (28, 28, 1)
@@ -253,7 +244,6 @@ def testConv2DLayer(self):
253244
self.assertEqual([4, 5, 1, 3],
254245
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
255246

256-
@keras_parameterized.run_all_keras_modes
257247
def testConv2DTransposeLayer(self):
258248
"""Verifies that we can cluster a Conv2DTranspose layer."""
259249
input_shape = (28, 28, 1)
@@ -266,7 +256,6 @@ def testConv2DTransposeLayer(self):
266256
self.assertEqual([4, 5, 3, 1],
267257
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
268258

269-
@keras_parameterized.run_all_keras_modes
270259
def testConv3DLayer(self):
271260
"""Verifies that we can cluster a Conv3D layer."""
272261
input_shape = (28, 28, 28, 1)
@@ -287,7 +276,6 @@ def testClusterKerasUnsupportedLayer(self):
287276
with self.assertRaises(ValueError):
288277
cluster.cluster_weights(keras_unsupported_layer, **self.params)
289278

290-
@keras_parameterized.run_all_keras_modes
291279
def testClusterCustomClusterableLayer(self):
292280
"""Verifies that a custom clusterable layer is being clustered correctly."""
293281
wrapped_layer = self._build_clustered_layer_model(
@@ -297,7 +285,6 @@ def testClusterCustomClusterableLayer(self):
297285
self.assertEqual([('kernel', wrapped_layer.layer.kernel)],
298286
wrapped_layer.layer.get_clusterable_weights())
299287

300-
@keras_parameterized.run_all_keras_modes
301288
def testClusterCustomClusterableLayerWithSparsityPreservation(self):
302289
"""Verifies that a custom clusterable layer is being clustered correctly when sparsity preservation is enabled."""
303290
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -424,7 +411,6 @@ def testStripClusteringSequentialModelWithBiasConstraint(self):
424411
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
425412
stripped_model.save(keras_file, save_traces=True)
426413

427-
@keras_parameterized.run_all_keras_modes
428414
def testClusterSequentialModelSelectively(self):
429415
clustered_model = keras.Sequential()
430416
clustered_model.add(
@@ -437,7 +423,6 @@ def testClusterSequentialModelSelectively(self):
437423
self.assertNotIsInstance(clustered_model.layers[1],
438424
cluster_wrapper.ClusterWeights)
439425

440-
@keras_parameterized.run_all_keras_modes
441426
def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
442427
"""Verifies that layers within a sequential model can be clustered selectively when sparsity preservation is enabled."""
443428
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -454,7 +439,6 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
454439
self.assertNotIsInstance(clustered_model.layers[1],
455440
cluster_wrapper.ClusterWeights)
456441

457-
@keras_parameterized.run_all_keras_modes
458442
def testClusterFunctionalModelSelectively(self):
459443
"""Verifies that layers within a functional model can be clustered selectively."""
460444
i1 = keras.Input(shape=(10,))
@@ -469,7 +453,6 @@ def testClusterFunctionalModelSelectively(self):
469453
self.assertNotIsInstance(clustered_model.layers[3],
470454
cluster_wrapper.ClusterWeights)
471455

472-
@keras_parameterized.run_all_keras_modes
473456
def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
474457
"""Verifies that layers within a functional model can be clustered selectively when sparsity preservation is enabled."""
475458
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -486,7 +469,6 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
486469
self.assertNotIsInstance(clustered_model.layers[3],
487470
cluster_wrapper.ClusterWeights)
488471

489-
@keras_parameterized.run_all_keras_modes
490472
def testClusterModelValidLayersSuccessful(self):
491473
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered."""
492474
model = keras.Sequential([
@@ -500,7 +482,6 @@ def testClusterModelValidLayersSuccessful(self):
500482
for layer, clustered_layer in zip(model.layers, clustered_model.layers):
501483
self._validate_clustered_layer(layer, clustered_layer)
502484

503-
@keras_parameterized.run_all_keras_modes
504485
def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self):
505486
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered when sparsity preservation is enabled."""
506487
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -540,7 +521,6 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
540521
self.custom_clusterable_layer, custom_non_clusterable_layer
541522
]), **self.params)
542523

543-
@keras_parameterized.run_all_keras_modes
544524
def testClusterModelDoesNotWrapAlreadyWrappedLayer(self):
545525
"""Verifies that clustering a model that contains an already clustered layer does not result in wrapping the clustered layer into another cluster_wrapper."""
546526
model = keras.Sequential([
@@ -579,7 +559,6 @@ def testClusterSequentialModelNoInput(self):
579559
clustered_model = cluster.cluster_weights(model, **self.params)
580560
self.assertEqual(self._count_clustered_layers(clustered_model), 2)
581561

582-
@keras_parameterized.run_all_keras_modes
583562
def testClusterSequentialModelWithInput(self):
584563
"""Verifies that a sequential model with an input layer is being clustered correctly."""
585564
# With InputLayer
@@ -607,7 +586,6 @@ def testClusterSequentialModelPreservesBuiltStateNoInput(self):
607586
json.loads(clustered_model.to_json()))
608587
self.assertEqual(loaded_model.built, False)
609588

610-
@keras_parameterized.run_all_keras_modes
611589
def testClusterSequentialModelPreservesBuiltStateWithInput(self):
612590
"""Verifies that clustering a sequential model with an input layer preserves the built state of the model."""
613591
# With InputLayer
@@ -625,7 +603,6 @@ def testClusterSequentialModelPreservesBuiltStateWithInput(self):
625603
json.loads(clustered_model.to_json()))
626604
self.assertEqual(loaded_model.built, True)
627605

628-
@keras_parameterized.run_all_keras_modes
629606
def testClusterFunctionalModelPreservesBuiltState(self):
630607
"""Verifies that clustering a functional model preserves the built state of the model."""
631608
i1 = keras.Input(shape=(10,))
@@ -644,7 +621,6 @@ def testClusterFunctionalModelPreservesBuiltState(self):
644621
json.loads(clustered_model.to_json()))
645622
self.assertEqual(loaded_model.built, True)
646623

647-
@keras_parameterized.run_all_keras_modes
648624
def testClusterFunctionalModel(self):
649625
"""Verifies that a functional model is being clustered correctly."""
650626
i1 = keras.Input(shape=(10,))
@@ -656,7 +632,6 @@ def testClusterFunctionalModel(self):
656632
clustered_model = cluster.cluster_weights(model, **self.params)
657633
self.assertEqual(self._count_clustered_layers(clustered_model), 3)
658634

659-
@keras_parameterized.run_all_keras_modes
660635
def testClusterFunctionalModelWithLayerReused(self):
661636
"""Verifies that a layer reused within a functional model multiple times is only being clustered once."""
662637
# The model reuses the Dense() layer. Make sure it's only clustered once.
@@ -668,22 +643,19 @@ def testClusterFunctionalModelWithLayerReused(self):
668643
clustered_model = cluster.cluster_weights(model, **self.params)
669644
self.assertEqual(self._count_clustered_layers(clustered_model), 1)
670645

671-
@keras_parameterized.run_all_keras_modes
672646
def testClusterSubclassModel(self):
673647
"""Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception."""
674648
model = TestModel()
675649
with self.assertRaises(ValueError):
676650
_ = cluster.cluster_weights(model, **self.params)
677651

678-
@keras_parameterized.run_all_keras_modes
679652
def testClusterSubclassModelAsSubmodel(self):
680653
"""Verifies that attempting to cluster a model with submodel that is a subclass throws an exception."""
681654
model_subclass = TestModel()
682655
model = keras.Sequential([layers.Dense(10), model_subclass])
683656
with self.assertRaisesRegex(ValueError, 'Subclassed models.*'):
684657
_ = cluster.cluster_weights(model, **self.params)
685658

686-
@keras_parameterized.run_all_keras_modes
687659
def testStripClusteringSequentialModel(self):
688660
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
689661
model = keras.Sequential([
@@ -697,7 +669,6 @@ def testStripClusteringSequentialModel(self):
697669
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
698670
self.assertEqual(model.get_config(), stripped_model.get_config())
699671

700-
@keras_parameterized.run_all_keras_modes
701672
def testClusterStrippingFunctionalModel(self):
702673
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
703674
i1 = keras.Input(shape=(10,))
@@ -713,7 +684,6 @@ def testClusterStrippingFunctionalModel(self):
713684
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
714685
self.assertEqual(model.get_config(), stripped_model.get_config())
715686

716-
@keras_parameterized.run_all_keras_modes
717687
def testClusterWeightsStrippedWeights(self):
718688
"""Verifies that stripping the clustering wrappers from a functional model preserves the clustered weights."""
719689
i1 = keras.Input(shape=(10,))
@@ -728,7 +698,6 @@ def testClusterWeightsStrippedWeights(self):
728698
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
729699
self.assertLen(stripped_model.get_weights(), cluster_weight_length)
730700

731-
@keras_parameterized.run_all_keras_modes
732701
def testStrippedKernel(self):
733702
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734703
i1 = keras.Input(shape=(1, 1, 1))
@@ -746,7 +715,6 @@ def testStrippedKernel(self):
746715
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
747716
self.assertIn(stripped_conv2d_layer.kernel, stripped_conv2d_layer.weights)
748717

749-
@keras_parameterized.run_all_keras_modes
750718
def testStripSelectivelyClusteredFunctionalModel(self):
751719
"""Verifies that invoking strip_clustering() on a selectively clustered functional model strips the clustering wrappers from the clustered layers."""
752720
i1 = keras.Input(shape=(10,))
@@ -761,7 +729,6 @@ def testStripSelectivelyClusteredFunctionalModel(self):
761729
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
762730
self.assertIsInstance(stripped_model.layers[2], layers.Dense)
763731

764-
@keras_parameterized.run_all_keras_modes
765732
def testStripSelectivelyClusteredSequentialModel(self):
766733
"""Verifies that invoking strip_clustering() on a selectively clustered sequential model strips the clustering wrappers from the clustered layers."""
767734
clustered_model = keras.Sequential([
@@ -775,7 +742,6 @@ def testStripSelectivelyClusteredSequentialModel(self):
775742
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
776743
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
777744

778-
@keras_parameterized.run_all_keras_modes
779745
def testStripClusteringAndSetOriginalWeightsBack(self):
780746
"""Verifies that we can set_weights onto the stripped model."""
781747
model = keras.Sequential([

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import numpy as np
1818
import tensorflow as tf
1919

20-
from tensorflow.python.keras import keras_parameterized
2120
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2221
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2322
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
@@ -30,7 +29,6 @@
3029
layers = tf.keras.layers
3130

3231

33-
@keras_parameterized.run_all_keras_modes
3432
class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
3533

3634
def setUp(self):

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import tensorflow as tf
2020

21-
from tensorflow.python.keras import keras_parameterized
2221

2322
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
2423
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
@@ -29,7 +28,6 @@
2928
layers = tf.keras.layers
3029

3130

32-
@keras_parameterized.run_all_keras_modes
3331
class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase,
3432
parameterized.TestCase):
3533

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve/prune_preserve_quantize_registry_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import tensorflow as tf
1919

20-
from tensorflow.python.keras import keras_parameterized
2120
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
2221
from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.prune_preserve import (
2322
prune_preserve_quantize_registry,)
@@ -28,7 +27,6 @@
2827
layers = tf.keras.layers
2928

3029

31-
@keras_parameterized.run_all_keras_modes
3230
class PrunePreserveQuantizeRegistryTest(tf.test.TestCase,
3331
parameterized.TestCase):
3432

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import numpy as np
2525
import tensorflow as tf
2626

27-
from tensorflow.python.keras import keras_parameterized
2827
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
2928
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
3029

@@ -73,7 +72,6 @@ def _assert_kernel_equality(self, a, b):
7372
self.assertAllEqual(a.numpy(), b.numpy())
7473

7574

76-
@keras_parameterized.run_all_keras_modes
7775
class QuantizeRegistryTest(
7876
tf.test.TestCase, parameterized.TestCase, _TestHelper):
7977

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,13 @@
2222

2323
import tensorflow as tf
2424

25-
from tensorflow.python.keras import keras_parameterized
2625
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
2726

2827
Default8BitConvWeightsQuantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer
2928

3029
keras = tf.keras
3130

3231

33-
@keras_parameterized.run_all_keras_modes
3432
class Default8BitConvWeightsQuantizerTest(tf.test.TestCase,
3533
parameterized.TestCase):
3634

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
import collections
1818
import inspect
1919

20+
from keras import backend
2021
import numpy as np
2122
import tensorflow as tf
2223

23-
from tensorflow.python.keras import backend
24-
2524
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
2625
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
2726
from tensorflow_model_optimization.python.core.quantization.keras import quantizers

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
import numpy as np
2323
import tensorflow as tf
2424

25-
from tensorflow.python.keras import keras_parameterized
2625
from tensorflow_model_optimization.python.core.quantization.keras import quantize
2726
from tensorflow_model_optimization.python.core.quantization.keras import utils
2827

2928

30-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
3129
class QuantizeNumericalTest(tf.test.TestCase, parameterized.TestCase):
3230

3331
def _batch(self, dims, batch_size):

0 commit comments

Comments
 (0)